test_tool.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import time
  2. import uuid
  3. from core.app.entities.app_invoke_entities import InvokeFrom
  4. from core.workflow.entities.node_entities import NodeRunResult
  5. from core.workflow.entities.variable_pool import VariablePool
  6. from core.workflow.enums import SystemVariableKey
  7. from core.workflow.graph_engine.entities.graph import Graph
  8. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  9. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  10. from core.workflow.nodes.tool.tool_node import ToolNode
  11. from models.enums import UserFrom
  12. from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
  def init_tool_node(config: dict):
      """Build a ``ToolNode`` inside a minimal two-node graph for tests.

      The graph consists of a ``start`` node plus the node described by
      ``config``, joined by one edge from ``start`` to the node id ``"1"``.

      :param config: node configuration dict; it is appended to the graph's
          node list and also handed to ``ToolNode`` as its ``config``.
      :return: a ``ToolNode`` with a freshly generated UUID as its id and a
          new runtime state backed by an otherwise empty variable pool.
      """
      start_node = {"data": {"type": "start"}, "id": "start"}
      start_edge = {
          "id": "start-source-next-target",
          "source": "start",
          "target": "1",
      }
      graph_config = {
          "edges": [start_edge],
          "nodes": [start_node, config],
      }

      graph = Graph.init(graph_config=graph_config)

      init_params = GraphInitParams(
          tenant_id="1",
          app_id="1",
          workflow_type=WorkflowType.WORKFLOW,
          workflow_id="1",
          graph_config=graph_config,
          user_id="1",
          user_from=UserFrom.ACCOUNT,
          invoke_from=InvokeFrom.DEBUGGER,
          call_depth=0,
      )

      # Only system variables are seeded here; each test adds its own
      # inputs to the pool after the node has been created.
      system_variables = {
          SystemVariableKey.FILES: [],
          SystemVariableKey.USER_ID: "aaa",
      }
      variable_pool = VariablePool(
          system_variables=system_variables,
          user_inputs={},
          environment_variables=[],
          conversation_variables=[],
      )

      node_id = str(uuid.uuid4())
      runtime_state = GraphRuntimeState(
          variable_pool=variable_pool,
          start_at=time.perf_counter(),
      )
      return ToolNode(
          id=node_id,
          graph_init_params=init_params,
          graph=graph,
          graph_runtime_state=runtime_state,
          config=config,
      )
  def test_tool_variable_invoke():
      """Run ``eval_expression`` with ``expression`` given as a variable selector.

      The selector ``["1", "123", "args1"]`` is filled with ``"1+1"`` in the
      variable pool, and the node's text output is expected to contain ``"2"``.
      """
      tool_parameters = {
          "expression": {
              "type": "variable",
              "value": ["1", "123", "args1"],
          }
      }
      node_data = {
          "title": "a",
          "desc": "a",
          "provider_id": "maths",
          "provider_type": "builtin",
          "provider_name": "maths",
          "tool_name": "eval_expression",
          "tool_label": "eval_expression",
          "tool_configurations": {},
          "tool_parameters": tool_parameters,
      }
      node = init_tool_node(config={"id": "1", "data": node_data})

      # Provide the value the selector above points at.
      pool = node.graph_runtime_state.variable_pool
      pool.add(["1", "123", "args1"], "1+1")

      # execute node
      outcome = node._run()

      assert isinstance(outcome, NodeRunResult)
      assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
      outputs = outcome.outputs
      assert outputs is not None
      assert "2" in outputs["text"]
      assert outputs["files"] == []
  def test_tool_mixed_invoke():
      """Run ``eval_expression`` with ``expression`` given as a mixed template.

      The template ``"{{#1.args1#}}"`` is used as the parameter value, the pool
      entry ``["1", "args1"]`` is set to ``"1+1"``, and the node's text output
      is expected to contain ``"2"``.
      """
      tool_parameters = {
          "expression": {
              "type": "mixed",
              "value": "{{#1.args1#}}",
          }
      }
      node_data = {
          "title": "a",
          "desc": "a",
          "provider_id": "maths",
          "provider_type": "builtin",
          "provider_name": "maths",
          "tool_name": "eval_expression",
          "tool_label": "eval_expression",
          "tool_configurations": {},
          "tool_parameters": tool_parameters,
      }
      node = init_tool_node(config={"id": "1", "data": node_data})

      # Provide the value referenced by the template placeholder.
      pool = node.graph_runtime_state.variable_pool
      pool.add(["1", "args1"], "1+1")

      # execute node
      outcome = node._run()

      assert isinstance(outcome, NodeRunResult)
      assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
      outputs = outcome.outputs
      assert outputs is not None
      assert "2" in outputs["text"]
      assert outputs["files"] == []