# test_llm.py
  1. import json
  2. import os
  3. import time
  4. import uuid
  5. from collections.abc import Generator
  6. from unittest.mock import MagicMock
  7. import pytest
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.workflow.entities.variable_pool import VariablePool
  10. from core.workflow.enums import SystemVariableKey
  11. from core.workflow.graph_engine.entities.graph import Graph
  12. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  13. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  14. from core.workflow.nodes.event import RunCompletedEvent
  15. from core.workflow.nodes.llm.node import LLMNode
  16. from extensions.ext_database import db
  17. from models.enums import UserFrom
  18. from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
  19. from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
  20. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  21. from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
  22. from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  23. def init_llm_node(config: dict) -> LLMNode:
  24. graph_config = {
  25. "edges": [
  26. {
  27. "id": "start-source-next-target",
  28. "source": "start",
  29. "target": "llm",
  30. },
  31. ],
  32. "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
  33. }
  34. graph = Graph.init(graph_config=graph_config)
  35. init_params = GraphInitParams(
  36. tenant_id="1",
  37. app_id="1",
  38. workflow_type=WorkflowType.WORKFLOW,
  39. workflow_id="1",
  40. graph_config=graph_config,
  41. user_id="1",
  42. user_from=UserFrom.ACCOUNT,
  43. invoke_from=InvokeFrom.DEBUGGER,
  44. call_depth=0,
  45. )
  46. # construct variable pool
  47. variable_pool = VariablePool(
  48. system_variables={
  49. SystemVariableKey.QUERY: "what's the weather today?",
  50. SystemVariableKey.FILES: [],
  51. SystemVariableKey.CONVERSATION_ID: "abababa",
  52. SystemVariableKey.USER_ID: "aaa",
  53. },
  54. user_inputs={},
  55. environment_variables=[],
  56. conversation_variables=[],
  57. )
  58. variable_pool.add(["abc", "output"], "sunny")
  59. node = LLMNode(
  60. id=str(uuid.uuid4()),
  61. graph_init_params=init_params,
  62. graph=graph,
  63. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  64. config=config,
  65. )
  66. return node
  67. def test_execute_llm(setup_model_mock):
  68. node = init_llm_node(
  69. config={
  70. "id": "llm",
  71. "data": {
  72. "title": "123",
  73. "type": "llm",
  74. "model": {
  75. "provider": "langgenius/openai/openai",
  76. "name": "gpt-3.5-turbo",
  77. "mode": "chat",
  78. "completion_params": {},
  79. },
  80. "prompt_template": [
  81. {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
  82. {"role": "user", "text": "{{#sys.query#}}"},
  83. ],
  84. "memory": None,
  85. "context": {"enabled": False},
  86. "vision": {"enabled": False},
  87. },
  88. },
  89. )
  90. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  91. # Mock db.session.close()
  92. db.session.close = MagicMock()
  93. node._fetch_model_config = get_mocked_fetch_model_config(
  94. provider="langgenius/openai/openai",
  95. model="gpt-3.5-turbo",
  96. mode="chat",
  97. credentials=credentials,
  98. )
  99. # execute node
  100. result = node._run()
  101. assert isinstance(result, Generator)
  102. for item in result:
  103. if isinstance(item, RunCompletedEvent):
  104. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  105. assert item.run_result.process_data is not None
  106. assert item.run_result.outputs is not None
  107. assert item.run_result.outputs.get("text") is not None
  108. assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
  109. @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
  110. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
  111. """
  112. Test execute LLM node with jinja2
  113. """
  114. node = init_llm_node(
  115. config={
  116. "id": "llm",
  117. "data": {
  118. "title": "123",
  119. "type": "llm",
  120. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  121. "prompt_config": {
  122. "jinja2_variables": [
  123. {"variable": "sys_query", "value_selector": ["sys", "query"]},
  124. {"variable": "output", "value_selector": ["abc", "output"]},
  125. ]
  126. },
  127. "prompt_template": [
  128. {
  129. "role": "system",
  130. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
  131. "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
  132. "edition_type": "jinja2",
  133. },
  134. {
  135. "role": "user",
  136. "text": "{{#sys.query#}}",
  137. "jinja2_text": "{{sys_query}}",
  138. "edition_type": "basic",
  139. },
  140. ],
  141. "memory": None,
  142. "context": {"enabled": False},
  143. "vision": {"enabled": False},
  144. },
  145. },
  146. )
  147. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  148. # Mock db.session.close()
  149. db.session.close = MagicMock()
  150. node._fetch_model_config = get_mocked_fetch_model_config(
  151. provider="langgenius/openai/openai",
  152. model="gpt-3.5-turbo",
  153. mode="chat",
  154. credentials=credentials,
  155. )
  156. # execute node
  157. result = node._run()
  158. for item in result:
  159. if isinstance(item, RunCompletedEvent):
  160. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  161. assert item.run_result.process_data is not None
  162. assert "sunny" in json.dumps(item.run_result.process_data)
  163. assert "what's the weather today?" in json.dumps(item.run_result.process_data)