瀏覽代碼

fix(api): fix fail branch functionality for `WorkflowTool` (#15966)

QuantumGhost 1 月之前
父節點
當前提交
2b4d1cf1db

+ 3 - 5
api/core/tools/workflow_as_tool/tool.py

@@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
+from core.tools.errors import ToolInvokeError
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping
 from models.account import Account
@@ -96,11 +97,8 @@ class WorkflowTool(Tool):
         assert isinstance(result, dict)
         data = result.get("data", {})
 
-        if data.get("error"):
-            raise Exception(data.get("error"))
-
-        if data.get("error"):
-            raise Exception(data.get("error"))
+        if err := data.get("error"):
+            raise ToolInvokeError(err)
 
         outputs = data.get("outputs")
         if outputs is None:

+ 3 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.errors import ToolInvokeError
 from core.tools.tool_engine import ToolEngine
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.variables.segments import ArrayAnySegment
@@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]):
         try:
             # convert tool messages
             yield from self._transform_message(message_stream, tool_info, parameters_for_log)
-        except PluginDaemonClientSideError as e:
+        except (PluginDaemonClientSideError, ToolInvokeError) as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
                     error=f"Failed to transform tool message: {str(e)}",
+                    error_type=type(e).__name__,
                 )
             )
 

+ 0 - 0
api/tests/unit_tests/core/tools/workflow_as_tool/__init__.py


+ 49 - 0
api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py

@@ -0,0 +1,49 @@
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
+from core.tools.errors import ToolInvokeError
+from core.tools.workflow_as_tool.tool import WorkflowTool
+
+
+def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
+    """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
+    `WorkflowAppGenerator.generate` returns a result with `error` key inside
+    the `data` element.
+    """
+    entity = ToolEntity(
+        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
+        parameters=[],
+        description=None,
+        output_schema=None,
+        has_runtime_parameters=False,
+    )
+    runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
+    tool = WorkflowTool(
+        workflow_app_id="",
+        workflow_as_tool_id="",
+        version="1",
+        workflow_entities={},
+        workflow_call_depth=1,
+        entity=entity,
+        runtime=runtime,
+    )
+
+    # needs to patch those methods to avoid database access.
+    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
+    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
+    monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
+
+    # replace `WorkflowAppGenerator.generate` 's return value.
+    monkeypatch.setattr(
+        "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
+        lambda *args, **kwargs: {"data": {"error": "oops"}},
+    )
+
+    with pytest.raises(ToolInvokeError) as exc_info:
+        # WorkflowTool always returns a generator, so we need to iterate to
+        # actually `run` the tool.
+        list(tool.invoke("test_user", {}))
+    assert exc_info.value.args == ("oops",)

+ 1 - 0
api/tests/unit_tests/core/workflow/nodes/tool/__init__.py

@@ -0,0 +1 @@
+

+ 110 - 0
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py

@@ -0,0 +1,110 @@
+from collections.abc import Generator
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
+from core.tools.errors import ToolInvokeError
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+from core.workflow.nodes.answer import AnswerStreamGenerateRoute
+from core.workflow.nodes.end import EndStreamParam
+from core.workflow.nodes.enums import ErrorStrategy
+from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.tool import ToolNode
+from core.workflow.nodes.tool.entities import ToolNodeData
+from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
+
+
+def _create_tool_node():
+    data = ToolNodeData(
+        title="Test Tool",
+        tool_parameters={},
+        provider_id="test_tool",
+        provider_type=ToolProviderType.WORKFLOW,
+        provider_name="test tool",
+        tool_name="test tool",
+        tool_label="test tool",
+        tool_configurations={},
+        plugin_unique_identifier=None,
+        desc="Exception handling test tool",
+        error_strategy=ErrorStrategy.FAIL_BRANCH,
+        version="1",
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = ToolNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
+            ),
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
+            ),
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+class MockToolRuntime:
+    def get_merged_runtime_parameters(self):
+        pass
+
+
+def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
+    yield from []
+    raise ToolInvokeError("oops")
+
+
+def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
+    """Ensure that ToolNode can handle ToolInvokeError when transforming
+    messages generated by ToolEngine.generic_invoke.
+    """
+    tool_node = _create_tool_node()
+
+    # Need to patch ToolManager and ToolEngine so that we don't
+    # have to set up a database.
+    monkeypatch.setattr(
+        "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
+    )
+    monkeypatch.setattr(
+        "core.tools.tool_engine.ToolEngine.generic_invoke",
+        lambda *args, **kwargs: mock_message_stream(),
+    )
+
+    streams = list(tool_node._run())
+    assert len(streams) == 1
+    stream = streams[0]
+    assert isinstance(stream, RunCompletedEvent)
+    result = stream.run_result
+    assert isinstance(result, NodeRunResult)
+    assert result.status == WorkflowNodeExecutionStatus.FAILED
+    assert "oops" in result.error
+    assert "Failed to transform tool message:" in result.error
+    assert result.error_type == "ToolInvokeError"