|
@@ -5,21 +5,23 @@ import uuid
|
|
from collections.abc import Generator, Mapping
|
|
from collections.abc import Generator, Mapping
|
|
from concurrent.futures import ThreadPoolExecutor, wait
|
|
from concurrent.futures import ThreadPoolExecutor, wait
|
|
from copy import copy, deepcopy
|
|
from copy import copy, deepcopy
|
|
-from typing import Any, Optional
|
|
|
|
|
|
+from typing import Any, Optional, cast
|
|
|
|
|
|
from flask import Flask, current_app
|
|
from flask import Flask, current_app
|
|
|
|
|
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
-from core.workflow.entities.node_entities import NodeRunMetadataKey
|
|
|
|
|
|
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
|
from core.workflow.graph_engine.entities.event import (
|
|
from core.workflow.graph_engine.entities.event import (
|
|
BaseIterationEvent,
|
|
BaseIterationEvent,
|
|
GraphEngineEvent,
|
|
GraphEngineEvent,
|
|
GraphRunFailedEvent,
|
|
GraphRunFailedEvent,
|
|
|
|
+ GraphRunPartialSucceededEvent,
|
|
GraphRunStartedEvent,
|
|
GraphRunStartedEvent,
|
|
GraphRunSucceededEvent,
|
|
GraphRunSucceededEvent,
|
|
|
|
+ NodeRunExceptionEvent,
|
|
NodeRunFailedEvent,
|
|
NodeRunFailedEvent,
|
|
NodeRunRetrieverResourceEvent,
|
|
NodeRunRetrieverResourceEvent,
|
|
NodeRunStartedEvent,
|
|
NodeRunStartedEvent,
|
|
@@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
|
|
from core.workflow.nodes import NodeType
|
|
from core.workflow.nodes import NodeType
|
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
|
from core.workflow.nodes.base import BaseNode
|
|
from core.workflow.nodes.base import BaseNode
|
|
|
|
+from core.workflow.nodes.base.entities import BaseNodeData
|
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|
|
|
+from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
|
from extensions.ext_database import db
|
|
from extensions.ext_database import db
|
|
@@ -128,6 +132,7 @@ class GraphEngine:
|
|
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
|
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
|
# trigger graph run start event
|
|
# trigger graph run start event
|
|
yield GraphRunStartedEvent()
|
|
yield GraphRunStartedEvent()
|
|
|
|
+ handle_exceptions = []
|
|
|
|
|
|
try:
|
|
try:
|
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
|
@@ -140,13 +145,17 @@ class GraphEngine:
|
|
)
|
|
)
|
|
|
|
|
|
# run graph
|
|
# run graph
|
|
- generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
|
|
|
-
|
|
|
|
|
|
+ generator = stream_processor.process(
|
|
|
|
+ self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
|
|
|
|
+ )
|
|
for item in generator:
|
|
for item in generator:
|
|
try:
|
|
try:
|
|
yield item
|
|
yield item
|
|
if isinstance(item, NodeRunFailedEvent):
|
|
if isinstance(item, NodeRunFailedEvent):
|
|
- yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
|
|
|
|
|
|
+ yield GraphRunFailedEvent(
|
|
|
|
+ error=item.route_node_state.failed_reason or "Unknown error.",
|
|
|
|
+ exceptions_count=len(handle_exceptions),
|
|
|
|
+ )
|
|
return
|
|
return
|
|
elif isinstance(item, NodeRunSucceededEvent):
|
|
elif isinstance(item, NodeRunSucceededEvent):
|
|
if item.node_type == NodeType.END:
|
|
if item.node_type == NodeType.END:
|
|
@@ -172,19 +181,24 @@ class GraphEngine:
|
|
].strip()
|
|
].strip()
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.exception("Graph run failed")
|
|
logger.exception("Graph run failed")
|
|
- yield GraphRunFailedEvent(error=str(e))
|
|
|
|
|
|
+ yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
|
return
|
|
return
|
|
-
|
|
|
|
- # trigger graph run success event
|
|
|
|
- yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
|
|
|
|
|
+ # count exceptions to determine partial success
|
|
|
|
+ if len(handle_exceptions) > 0:
|
|
|
|
+ yield GraphRunPartialSucceededEvent(
|
|
|
|
+ exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ # trigger graph run success event
|
|
|
|
+ yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
|
self._release_thread()
|
|
self._release_thread()
|
|
except GraphRunFailedError as e:
|
|
except GraphRunFailedError as e:
|
|
- yield GraphRunFailedEvent(error=e.error)
|
|
|
|
|
|
+ yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
|
|
self._release_thread()
|
|
self._release_thread()
|
|
return
|
|
return
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.exception("Unknown Error when graph running")
|
|
logger.exception("Unknown Error when graph running")
|
|
- yield GraphRunFailedEvent(error=str(e))
|
|
|
|
|
|
+ yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
|
self._release_thread()
|
|
self._release_thread()
|
|
raise e
|
|
raise e
|
|
|
|
|
|
@@ -198,6 +212,7 @@ class GraphEngine:
|
|
in_parallel_id: Optional[str] = None,
|
|
in_parallel_id: Optional[str] = None,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
|
|
+ handle_exceptions: list[str] = [],
|
|
) -> Generator[GraphEngineEvent, None, None]:
|
|
) -> Generator[GraphEngineEvent, None, None]:
|
|
parallel_start_node_id = None
|
|
parallel_start_node_id = None
|
|
if in_parallel_id:
|
|
if in_parallel_id:
|
|
@@ -242,7 +257,7 @@ class GraphEngine:
|
|
previous_node_id=previous_node_id,
|
|
previous_node_id=previous_node_id,
|
|
thread_pool_id=self.thread_pool_id,
|
|
thread_pool_id=self.thread_pool_id,
|
|
)
|
|
)
|
|
-
|
|
|
|
|
|
+ node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
|
try:
|
|
try:
|
|
# run node
|
|
# run node
|
|
generator = self._run_node(
|
|
generator = self._run_node(
|
|
@@ -252,6 +267,7 @@ class GraphEngine:
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
parent_parallel_id=parent_parallel_id,
|
|
parent_parallel_id=parent_parallel_id,
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
)
|
|
)
|
|
|
|
|
|
for item in generator:
|
|
for item in generator:
|
|
@@ -301,7 +317,12 @@ class GraphEngine:
|
|
|
|
|
|
if len(edge_mappings) == 1:
|
|
if len(edge_mappings) == 1:
|
|
edge = edge_mappings[0]
|
|
edge = edge_mappings[0]
|
|
-
|
|
|
|
|
|
+ if (
|
|
|
|
+ previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
|
|
|
+ and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
|
|
|
+ and edge.run_condition is None
|
|
|
|
+ ):
|
|
|
|
+ break
|
|
if edge.run_condition:
|
|
if edge.run_condition:
|
|
result = ConditionManager.get_condition_handler(
|
|
result = ConditionManager.get_condition_handler(
|
|
init_params=self.init_params,
|
|
init_params=self.init_params,
|
|
@@ -334,7 +355,7 @@ class GraphEngine:
|
|
if len(sub_edge_mappings) == 0:
|
|
if len(sub_edge_mappings) == 0:
|
|
continue
|
|
continue
|
|
|
|
|
|
- edge = sub_edge_mappings[0]
|
|
|
|
|
|
+ edge = cast(GraphEdge, sub_edge_mappings[0])
|
|
|
|
|
|
result = ConditionManager.get_condition_handler(
|
|
result = ConditionManager.get_condition_handler(
|
|
init_params=self.init_params,
|
|
init_params=self.init_params,
|
|
@@ -355,6 +376,7 @@ class GraphEngine:
|
|
edge_mappings=sub_edge_mappings,
|
|
edge_mappings=sub_edge_mappings,
|
|
in_parallel_id=in_parallel_id,
|
|
in_parallel_id=in_parallel_id,
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
)
|
|
)
|
|
|
|
|
|
for item in parallel_generator:
|
|
for item in parallel_generator:
|
|
@@ -369,11 +391,18 @@ class GraphEngine:
|
|
break
|
|
break
|
|
|
|
|
|
next_node_id = final_node_id
|
|
next_node_id = final_node_id
|
|
|
|
+ elif (
|
|
|
|
+ node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
|
|
|
+ and node_instance.should_continue_on_error
|
|
|
|
+ and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
|
|
|
+ ):
|
|
|
|
+ break
|
|
else:
|
|
else:
|
|
parallel_generator = self._run_parallel_branches(
|
|
parallel_generator = self._run_parallel_branches(
|
|
edge_mappings=edge_mappings,
|
|
edge_mappings=edge_mappings,
|
|
in_parallel_id=in_parallel_id,
|
|
in_parallel_id=in_parallel_id,
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
)
|
|
)
|
|
|
|
|
|
for item in parallel_generator:
|
|
for item in parallel_generator:
|
|
@@ -395,6 +424,7 @@ class GraphEngine:
|
|
edge_mappings: list[GraphEdge],
|
|
edge_mappings: list[GraphEdge],
|
|
in_parallel_id: Optional[str] = None,
|
|
in_parallel_id: Optional[str] = None,
|
|
parallel_start_node_id: Optional[str] = None,
|
|
parallel_start_node_id: Optional[str] = None,
|
|
|
|
+ handle_exceptions: list[str] = [],
|
|
) -> Generator[GraphEngineEvent | str, None, None]:
|
|
) -> Generator[GraphEngineEvent | str, None, None]:
|
|
# if nodes has no run conditions, parallel run all nodes
|
|
# if nodes has no run conditions, parallel run all nodes
|
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
|
@@ -438,6 +468,7 @@ class GraphEngine:
|
|
"parallel_start_node_id": edge.target_node_id,
|
|
"parallel_start_node_id": edge.target_node_id,
|
|
"parent_parallel_id": in_parallel_id,
|
|
"parent_parallel_id": in_parallel_id,
|
|
"parent_parallel_start_node_id": parallel_start_node_id,
|
|
"parent_parallel_start_node_id": parallel_start_node_id,
|
|
|
|
+ "handle_exceptions": handle_exceptions,
|
|
},
|
|
},
|
|
)
|
|
)
|
|
|
|
|
|
@@ -481,6 +512,7 @@ class GraphEngine:
|
|
parallel_start_node_id: str,
|
|
parallel_start_node_id: str,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
|
|
+ handle_exceptions: list[str] = [],
|
|
) -> None:
|
|
) -> None:
|
|
"""
|
|
"""
|
|
Run parallel nodes
|
|
Run parallel nodes
|
|
@@ -502,6 +534,7 @@ class GraphEngine:
|
|
in_parallel_id=parallel_id,
|
|
in_parallel_id=parallel_id,
|
|
parent_parallel_id=parent_parallel_id,
|
|
parent_parallel_id=parent_parallel_id,
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
)
|
|
)
|
|
|
|
|
|
for item in generator:
|
|
for item in generator:
|
|
@@ -548,6 +581,7 @@ class GraphEngine:
|
|
parallel_start_node_id: Optional[str] = None,
|
|
parallel_start_node_id: Optional[str] = None,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
parent_parallel_start_node_id: Optional[str] = None,
|
|
|
|
+ handle_exceptions: list[str] = [],
|
|
) -> Generator[GraphEngineEvent, None, None]:
|
|
) -> Generator[GraphEngineEvent, None, None]:
|
|
"""
|
|
"""
|
|
Run node
|
|
Run node
|
|
@@ -587,19 +621,55 @@ class GraphEngine:
|
|
route_node_state.set_finished(run_result=run_result)
|
|
route_node_state.set_finished(run_result=run_result)
|
|
|
|
|
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
- yield NodeRunFailedEvent(
|
|
|
|
- error=route_node_state.failed_reason or "Unknown error.",
|
|
|
|
- id=node_instance.id,
|
|
|
|
- node_id=node_instance.node_id,
|
|
|
|
- node_type=node_instance.node_type,
|
|
|
|
- node_data=node_instance.node_data,
|
|
|
|
- route_node_state=route_node_state,
|
|
|
|
- parallel_id=parallel_id,
|
|
|
|
- parallel_start_node_id=parallel_start_node_id,
|
|
|
|
- parent_parallel_id=parent_parallel_id,
|
|
|
|
- parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
|
- )
|
|
|
|
|
|
+ if node_instance.should_continue_on_error:
|
|
|
|
+ # if run failed, handle error
|
|
|
|
+ run_result = self._handle_continue_on_error(
|
|
|
|
+ node_instance,
|
|
|
|
+ item.run_result,
|
|
|
|
+ self.graph_runtime_state.variable_pool,
|
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
|
|
+ )
|
|
|
|
+ route_node_state.node_run_result = run_result
|
|
|
|
+ route_node_state.status = RouteNodeState.Status.EXCEPTION
|
|
|
|
+ if run_result.outputs:
|
|
|
|
+ for variable_key, variable_value in run_result.outputs.items():
|
|
|
|
+ # append variables to variable pool recursively
|
|
|
|
+ self._append_variables_recursively(
|
|
|
|
+ node_id=node_instance.node_id,
|
|
|
|
+ variable_key_list=[variable_key],
|
|
|
|
+ variable_value=variable_value,
|
|
|
|
+ )
|
|
|
|
+ yield NodeRunExceptionEvent(
|
|
|
|
+ error=run_result.error or "System Error",
|
|
|
|
+ id=node_instance.id,
|
|
|
|
+ node_id=node_instance.node_id,
|
|
|
|
+ node_type=node_instance.node_type,
|
|
|
|
+ node_data=node_instance.node_data,
|
|
|
|
+ route_node_state=route_node_state,
|
|
|
|
+ parallel_id=parallel_id,
|
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ yield NodeRunFailedEvent(
|
|
|
|
+ error=route_node_state.failed_reason or "Unknown error.",
|
|
|
|
+ id=node_instance.id,
|
|
|
|
+ node_id=node_instance.node_id,
|
|
|
|
+ node_type=node_instance.node_type,
|
|
|
|
+ node_data=node_instance.node_data,
|
|
|
|
+ route_node_state=route_node_state,
|
|
|
|
+ parallel_id=parallel_id,
|
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
|
+ )
|
|
|
|
+
|
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
|
|
|
+ if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
|
|
|
+ node_instance.node_id
|
|
|
|
+ ):
|
|
|
|
+ run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
|
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
|
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
|
# plus state total_tokens
|
|
# plus state total_tokens
|
|
self.graph_runtime_state.total_tokens += int(
|
|
self.graph_runtime_state.total_tokens += int(
|
|
@@ -735,6 +805,56 @@ class GraphEngine:
|
|
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
|
|
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
|
|
return new_instance
|
|
return new_instance
|
|
|
|
|
|
|
|
+ def _handle_continue_on_error(
|
|
|
|
+ self,
|
|
|
|
+ node_instance: BaseNode[BaseNodeData],
|
|
|
|
+ error_result: NodeRunResult,
|
|
|
|
+ variable_pool: VariablePool,
|
|
|
|
+ handle_exceptions: list[str] = [],
|
|
|
|
+ ) -> NodeRunResult:
|
|
|
|
+ """
|
|
|
|
+ handle continue on error when self._should_continue_on_error is True
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ :param error_result (NodeRunResult): error run result
|
|
|
|
+ :param variable_pool (VariablePool): variable pool
|
|
|
|
+ :return: excption run result
|
|
|
|
+ """
|
|
|
|
+ # add error message and error type to variable pool
|
|
|
|
+ variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
|
|
|
|
+ variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
|
|
|
|
+ # add error message to handle_exceptions
|
|
|
|
+ handle_exceptions.append(error_result.error)
|
|
|
|
+ node_error_args = {
|
|
|
|
+ "status": WorkflowNodeExecutionStatus.EXCEPTION,
|
|
|
|
+ "error": error_result.error,
|
|
|
|
+ "inputs": error_result.inputs,
|
|
|
|
+ "metadata": {
|
|
|
|
+ NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
|
|
|
+ return NodeRunResult(
|
|
|
|
+ **node_error_args,
|
|
|
|
+ outputs={
|
|
|
|
+ **node_instance.node_data.default_value_dict,
|
|
|
|
+ "error_message": error_result.error,
|
|
|
|
+ "error_type": error_result.error_type,
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+ elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
|
|
|
|
+ if self.graph.edge_mapping.get(node_instance.node_id):
|
|
|
|
+ node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
|
|
|
|
+ return NodeRunResult(
|
|
|
|
+ **node_error_args,
|
|
|
|
+ outputs={
|
|
|
|
+ "error_message": error_result.error,
|
|
|
|
+ "error_type": error_result.error_type,
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+ return error_result
|
|
|
|
+
|
|
|
|
|
|
class GraphRunFailedError(Exception):
|
|
class GraphRunFailedError(Exception):
|
|
def __init__(self, error: str):
|
|
def __init__(self, error: str):
|