test_tool.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import time
  2. import uuid
  3. from unittest.mock import MagicMock
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.tools.utils.configuration import ToolParameterConfigurationManager
  6. from core.workflow.entities.variable_pool import VariablePool
  7. from core.workflow.enums import SystemVariableKey
  8. from core.workflow.graph_engine.entities.graph import Graph
  9. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  10. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  11. from core.workflow.nodes.event.event import RunCompletedEvent
  12. from core.workflow.nodes.tool.tool_node import ToolNode
  13. from models.enums import UserFrom
  14. from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
  15. def init_tool_node(config: dict):
  16. graph_config = {
  17. "edges": [
  18. {
  19. "id": "start-source-next-target",
  20. "source": "start",
  21. "target": "1",
  22. },
  23. ],
  24. "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
  25. }
  26. graph = Graph.init(graph_config=graph_config)
  27. init_params = GraphInitParams(
  28. tenant_id="1",
  29. app_id="1",
  30. workflow_type=WorkflowType.WORKFLOW,
  31. workflow_id="1",
  32. graph_config=graph_config,
  33. user_id="1",
  34. user_from=UserFrom.ACCOUNT,
  35. invoke_from=InvokeFrom.DEBUGGER,
  36. call_depth=0,
  37. )
  38. # construct variable pool
  39. variable_pool = VariablePool(
  40. system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
  41. user_inputs={},
  42. environment_variables=[],
  43. conversation_variables=[],
  44. )
  45. return ToolNode(
  46. id=str(uuid.uuid4()),
  47. graph_init_params=init_params,
  48. graph=graph,
  49. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  50. config=config,
  51. )
  52. def test_tool_variable_invoke():
  53. node = init_tool_node(
  54. config={
  55. "id": "1",
  56. "data": {
  57. "title": "a",
  58. "desc": "a",
  59. "provider_id": "time",
  60. "provider_type": "builtin",
  61. "provider_name": "time",
  62. "tool_name": "current_time",
  63. "tool_label": "current_time",
  64. "tool_configurations": {},
  65. "tool_parameters": {},
  66. },
  67. }
  68. )
  69. ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
  70. node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
  71. # execute node
  72. result = node._run()
  73. for item in result:
  74. if isinstance(item, RunCompletedEvent):
  75. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  76. assert item.run_result.outputs is not None
  77. assert item.run_result.outputs.get("text") is not None
  78. def test_tool_mixed_invoke():
  79. node = init_tool_node(
  80. config={
  81. "id": "1",
  82. "data": {
  83. "title": "a",
  84. "desc": "a",
  85. "provider_id": "time",
  86. "provider_type": "builtin",
  87. "provider_name": "time",
  88. "tool_name": "current_time",
  89. "tool_label": "current_time",
  90. "tool_configurations": {
  91. "format": "%Y-%m-%d %H:%M:%S",
  92. },
  93. "tool_parameters": {},
  94. },
  95. }
  96. )
  97. ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
  98. # execute node
  99. result = node._run()
  100. for item in result:
  101. if isinstance(item, RunCompletedEvent):
  102. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  103. assert item.run_result.outputs is not None
  104. assert item.run_result.outputs.get("text") is not None