# test_llm.py — integration tests for the workflow LLM node.
import json
import os
from unittest.mock import MagicMock

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  21. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  22. def test_execute_llm(setup_openai_mock):
  23. node = LLMNode(
  24. tenant_id="1",
  25. app_id="1",
  26. workflow_id="1",
  27. user_id="1",
  28. invoke_from=InvokeFrom.WEB_APP,
  29. user_from=UserFrom.ACCOUNT,
  30. config={
  31. "id": "llm",
  32. "data": {
  33. "title": "123",
  34. "type": "llm",
  35. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  36. "prompt_template": [
  37. {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
  38. {"role": "user", "text": "{{#sys.query#}}"},
  39. ],
  40. "memory": None,
  41. "context": {"enabled": False},
  42. "vision": {"enabled": False},
  43. },
  44. },
  45. )
  46. # construct variable pool
  47. pool = VariablePool(
  48. system_variables={
  49. SystemVariableKey.QUERY: "what's the weather today?",
  50. SystemVariableKey.FILES: [],
  51. SystemVariableKey.CONVERSATION_ID: "abababa",
  52. SystemVariableKey.USER_ID: "aaa",
  53. },
  54. user_inputs={},
  55. environment_variables=[],
  56. )
  57. pool.add(["abc", "output"], "sunny")
  58. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  59. provider_instance = ModelProviderFactory().get_provider_instance("openai")
  60. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  61. provider_model_bundle = ProviderModelBundle(
  62. configuration=ProviderConfiguration(
  63. tenant_id="1",
  64. provider=provider_instance.get_provider_schema(),
  65. preferred_provider_type=ProviderType.CUSTOM,
  66. using_provider_type=ProviderType.CUSTOM,
  67. system_configuration=SystemConfiguration(enabled=False),
  68. custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)),
  69. model_settings=[],
  70. ),
  71. provider_instance=provider_instance,
  72. model_type_instance=model_type_instance,
  73. )
  74. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
  75. model_config = ModelConfigWithCredentialsEntity(
  76. model="gpt-3.5-turbo",
  77. provider="openai",
  78. mode="chat",
  79. credentials=credentials,
  80. parameters={},
  81. model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
  82. provider_model_bundle=provider_model_bundle,
  83. )
  84. # Mock db.session.close()
  85. db.session.close = MagicMock()
  86. node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
  87. # execute node
  88. result = node.run(pool)
  89. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  90. assert result.outputs["text"] is not None
  91. assert result.outputs["usage"]["total_tokens"] > 0
  92. @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
  93. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  94. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
  95. """
  96. Test execute LLM node with jinja2
  97. """
  98. node = LLMNode(
  99. tenant_id="1",
  100. app_id="1",
  101. workflow_id="1",
  102. user_id="1",
  103. invoke_from=InvokeFrom.WEB_APP,
  104. user_from=UserFrom.ACCOUNT,
  105. config={
  106. "id": "llm",
  107. "data": {
  108. "title": "123",
  109. "type": "llm",
  110. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  111. "prompt_config": {
  112. "jinja2_variables": [
  113. {"variable": "sys_query", "value_selector": ["sys", "query"]},
  114. {"variable": "output", "value_selector": ["abc", "output"]},
  115. ]
  116. },
  117. "prompt_template": [
  118. {
  119. "role": "system",
  120. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
  121. "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
  122. "edition_type": "jinja2",
  123. },
  124. {
  125. "role": "user",
  126. "text": "{{#sys.query#}}",
  127. "jinja2_text": "{{sys_query}}",
  128. "edition_type": "basic",
  129. },
  130. ],
  131. "memory": None,
  132. "context": {"enabled": False},
  133. "vision": {"enabled": False},
  134. },
  135. },
  136. )
  137. # construct variable pool
  138. pool = VariablePool(
  139. system_variables={
  140. SystemVariableKey.QUERY: "what's the weather today?",
  141. SystemVariableKey.FILES: [],
  142. SystemVariableKey.CONVERSATION_ID: "abababa",
  143. SystemVariableKey.USER_ID: "aaa",
  144. },
  145. user_inputs={},
  146. environment_variables=[],
  147. )
  148. pool.add(["abc", "output"], "sunny")
  149. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  150. provider_instance = ModelProviderFactory().get_provider_instance("openai")
  151. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  152. provider_model_bundle = ProviderModelBundle(
  153. configuration=ProviderConfiguration(
  154. tenant_id="1",
  155. provider=provider_instance.get_provider_schema(),
  156. preferred_provider_type=ProviderType.CUSTOM,
  157. using_provider_type=ProviderType.CUSTOM,
  158. system_configuration=SystemConfiguration(enabled=False),
  159. custom_configuration=CustomConfiguration(provider=CustomProviderConfiguration(credentials=credentials)),
  160. model_settings=[],
  161. ),
  162. provider_instance=provider_instance,
  163. model_type_instance=model_type_instance,
  164. )
  165. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
  166. model_config = ModelConfigWithCredentialsEntity(
  167. model="gpt-3.5-turbo",
  168. provider="openai",
  169. mode="chat",
  170. credentials=credentials,
  171. parameters={},
  172. model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
  173. provider_model_bundle=provider_model_bundle,
  174. )
  175. # Mock db.session.close()
  176. db.session.close = MagicMock()
  177. node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
  178. # execute node
  179. result = node.run(pool)
  180. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  181. assert "sunny" in json.dumps(result.process_data)
  182. assert "what's the weather today?" in json.dumps(result.process_data)