|
@@ -0,0 +1,110 @@
|
|
|
+from collections.abc import Generator
|
|
|
+
|
|
|
+import pytest
|
|
|
+
|
|
|
+from core.app.entities.app_invoke_entities import InvokeFrom
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
|
|
+from core.tools.errors import ToolInvokeError
|
|
|
+from core.workflow.entities.node_entities import NodeRunResult
|
|
|
+from core.workflow.entities.variable_pool import VariablePool
|
|
|
+from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
|
|
+from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
|
|
+from core.workflow.nodes.end import EndStreamParam
|
|
|
+from core.workflow.nodes.enums import ErrorStrategy
|
|
|
+from core.workflow.nodes.event import RunCompletedEvent
|
|
|
+from core.workflow.nodes.tool import ToolNode
|
|
|
+from core.workflow.nodes.tool.entities import ToolNodeData
|
|
|
+from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
|
|
|
+
|
|
|
+
|
|
|
+def _create_tool_node():
|
|
|
+ data = ToolNodeData(
|
|
|
+ title="Test Tool",
|
|
|
+ tool_parameters={},
|
|
|
+ provider_id="test_tool",
|
|
|
+ provider_type=ToolProviderType.WORKFLOW,
|
|
|
+ provider_name="test tool",
|
|
|
+ tool_name="test tool",
|
|
|
+ tool_label="test tool",
|
|
|
+ tool_configurations={},
|
|
|
+ plugin_unique_identifier=None,
|
|
|
+ desc="Exception handling test tool",
|
|
|
+ error_strategy=ErrorStrategy.FAIL_BRANCH,
|
|
|
+ version="1",
|
|
|
+ )
|
|
|
+ variable_pool = VariablePool(
|
|
|
+ system_variables={},
|
|
|
+ user_inputs={},
|
|
|
+ )
|
|
|
+ node = ToolNode(
|
|
|
+ id="1",
|
|
|
+ config={
|
|
|
+ "id": "1",
|
|
|
+ "data": data.model_dump(),
|
|
|
+ },
|
|
|
+ graph_init_params=GraphInitParams(
|
|
|
+ tenant_id="1",
|
|
|
+ app_id="1",
|
|
|
+ workflow_type=WorkflowType.WORKFLOW,
|
|
|
+ workflow_id="1",
|
|
|
+ graph_config={},
|
|
|
+ user_id="1",
|
|
|
+ user_from=UserFrom.ACCOUNT,
|
|
|
+ invoke_from=InvokeFrom.SERVICE_API,
|
|
|
+ call_depth=0,
|
|
|
+ ),
|
|
|
+ graph=Graph(
|
|
|
+ root_node_id="1",
|
|
|
+ answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
|
|
+ answer_dependencies={},
|
|
|
+ answer_generate_route={},
|
|
|
+ ),
|
|
|
+ end_stream_param=EndStreamParam(
|
|
|
+ end_dependencies={},
|
|
|
+ end_stream_variable_selector_mapping={},
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ graph_runtime_state=GraphRuntimeState(
|
|
|
+ variable_pool=variable_pool,
|
|
|
+ start_at=0,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ return node
|
|
|
+
|
|
|
+
|
|
|
+class MockToolRuntime:
|
|
|
+ def get_merged_runtime_parameters(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
|
|
+ yield from []
|
|
|
+ raise ToolInvokeError("oops")
|
|
|
+
|
|
|
+
|
|
|
+def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
|
|
+ """Ensure that ToolNode can handle ToolInvokeError when transforming
|
|
|
+ messages generated by ToolEngine.generic_invoke.
|
|
|
+ """
|
|
|
+ tool_node = _create_tool_node()
|
|
|
+
|
|
|
+ # Need to patch ToolManager and ToolEngine so that we don't
|
|
|
+ # have to set up a database.
|
|
|
+ monkeypatch.setattr(
|
|
|
+ "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
|
|
+ )
|
|
|
+ monkeypatch.setattr(
|
|
|
+ "core.tools.tool_engine.ToolEngine.generic_invoke",
|
|
|
+ lambda *args, **kwargs: mock_message_stream(),
|
|
|
+ )
|
|
|
+
|
|
|
+ streams = list(tool_node._run())
|
|
|
+ assert len(streams) == 1
|
|
|
+ stream = streams[0]
|
|
|
+ assert isinstance(stream, RunCompletedEvent)
|
|
|
+ result = stream.run_result
|
|
|
+ assert isinstance(result, NodeRunResult)
|
|
|
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
|
|
|
+ assert "oops" in result.error
|
|
|
+ assert "Failed to transform tool message:" in result.error
|
|
|
+ assert result.error_type == "ToolInvokeError"
|