|
@@ -5,6 +5,7 @@ import uuid
|
|
|
from collections.abc import Generator, Mapping
|
|
|
from concurrent.futures import ThreadPoolExecutor, wait
|
|
|
from copy import copy, deepcopy
|
|
|
+from datetime import UTC, datetime
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
from flask import Flask, current_app
|
|
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
|
|
|
NodeRunExceptionEvent,
|
|
|
NodeRunFailedEvent,
|
|
|
NodeRunRetrieverResourceEvent,
|
|
|
+ NodeRunRetryEvent,
|
|
|
NodeRunStartedEvent,
|
|
|
NodeRunStreamChunkEvent,
|
|
|
NodeRunSucceededEvent,
|
|
@@ -581,7 +583,7 @@ class GraphEngine:
|
|
|
|
|
|
def _run_node(
|
|
|
self,
|
|
|
- node_instance: BaseNode,
|
|
|
+ node_instance: BaseNode[BaseNodeData],
|
|
|
route_node_state: RouteNodeState,
|
|
|
parallel_id: Optional[str] = None,
|
|
|
parallel_start_node_id: Optional[str] = None,
|
|
@@ -607,36 +609,121 @@ class GraphEngine:
|
|
|
)
|
|
|
|
|
|
db.session.close()
|
|
|
+ max_retries = node_instance.node_data.retry_config.max_retries
|
|
|
+ retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
|
|
+ retries = 0
|
|
|
+ shoudl_continue_retry = True
|
|
|
+ while shoudl_continue_retry and retries <= max_retries:
|
|
|
+ try:
|
|
|
+ # run node
|
|
|
+ retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
+ generator = node_instance.run()
|
|
|
+ for item in generator:
|
|
|
+ if isinstance(item, GraphEngineEvent):
|
|
|
+ if isinstance(item, BaseIterationEvent):
|
|
|
+ # add parallel info to iteration event
|
|
|
+ item.parallel_id = parallel_id
|
|
|
+ item.parallel_start_node_id = parallel_start_node_id
|
|
|
+ item.parent_parallel_id = parent_parallel_id
|
|
|
+ item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
|
|
+
|
|
|
+ yield item
|
|
|
+ else:
|
|
|
+ if isinstance(item, RunCompletedEvent):
|
|
|
+ run_result = item.run_result
|
|
|
+ if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
|
+ if (
|
|
|
+ retries == max_retries
|
|
|
+ and node_instance.node_type == NodeType.HTTP_REQUEST
|
|
|
+ and run_result.outputs
|
|
|
+ and not node_instance.should_continue_on_error
|
|
|
+ ):
|
|
|
+ run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
|
|
+ if node_instance.should_retry and retries < max_retries:
|
|
|
+ retries += 1
|
|
|
+ self.graph_runtime_state.node_run_steps += 1
|
|
|
+ route_node_state.node_run_result = run_result
|
|
|
+ yield NodeRunRetryEvent(
|
|
|
+ id=node_instance.id,
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ node_type=node_instance.node_type,
|
|
|
+ node_data=node_instance.node_data,
|
|
|
+ route_node_state=route_node_state,
|
|
|
+ error=run_result.error,
|
|
|
+ retry_index=retries,
|
|
|
+ parallel_id=parallel_id,
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
+ start_at=retry_start_at,
|
|
|
+ start_index=self.graph_runtime_state.node_run_steps,
|
|
|
+ )
|
|
|
+ time.sleep(retry_interval)
|
|
|
+ continue
|
|
|
+ route_node_state.set_finished(run_result=run_result)
|
|
|
+
|
|
|
+ if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
|
+ if node_instance.should_continue_on_error:
|
|
|
+ # if run failed, handle error
|
|
|
+ run_result = self._handle_continue_on_error(
|
|
|
+ node_instance,
|
|
|
+ item.run_result,
|
|
|
+ self.graph_runtime_state.variable_pool,
|
|
|
+ handle_exceptions=handle_exceptions,
|
|
|
+ )
|
|
|
+ route_node_state.node_run_result = run_result
|
|
|
+ route_node_state.status = RouteNodeState.Status.EXCEPTION
|
|
|
+ if run_result.outputs:
|
|
|
+ for variable_key, variable_value in run_result.outputs.items():
|
|
|
+ # append variables to variable pool recursively
|
|
|
+ self._append_variables_recursively(
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ variable_key_list=[variable_key],
|
|
|
+ variable_value=variable_value,
|
|
|
+ )
|
|
|
+ yield NodeRunExceptionEvent(
|
|
|
+ error=run_result.error or "System Error",
|
|
|
+ id=node_instance.id,
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ node_type=node_instance.node_type,
|
|
|
+ node_data=node_instance.node_data,
|
|
|
+ route_node_state=route_node_state,
|
|
|
+ parallel_id=parallel_id,
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
+ )
|
|
|
+ shoudl_continue_retry = False
|
|
|
+ else:
|
|
|
+ yield NodeRunFailedEvent(
|
|
|
+ error=route_node_state.failed_reason or "Unknown error.",
|
|
|
+ id=node_instance.id,
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ node_type=node_instance.node_type,
|
|
|
+ node_data=node_instance.node_data,
|
|
|
+ route_node_state=route_node_state,
|
|
|
+ parallel_id=parallel_id,
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
+ )
|
|
|
+ shoudl_continue_retry = False
|
|
|
+ elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
|
|
+ if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
|
|
+ node_instance.node_id
|
|
|
+ ):
|
|
|
+ run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
|
|
+ if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
|
|
+ # plus state total_tokens
|
|
|
+ self.graph_runtime_state.total_tokens += int(
|
|
|
+ run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
|
|
+ )
|
|
|
|
|
|
- try:
|
|
|
- # run node
|
|
|
- generator = node_instance.run()
|
|
|
- for item in generator:
|
|
|
- if isinstance(item, GraphEngineEvent):
|
|
|
- if isinstance(item, BaseIterationEvent):
|
|
|
- # add parallel info to iteration event
|
|
|
- item.parallel_id = parallel_id
|
|
|
- item.parallel_start_node_id = parallel_start_node_id
|
|
|
- item.parent_parallel_id = parent_parallel_id
|
|
|
- item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
|
|
+ if run_result.llm_usage:
|
|
|
+ # use the latest usage
|
|
|
+ self.graph_runtime_state.llm_usage += run_result.llm_usage
|
|
|
|
|
|
- yield item
|
|
|
- else:
|
|
|
- if isinstance(item, RunCompletedEvent):
|
|
|
- run_result = item.run_result
|
|
|
- route_node_state.set_finished(run_result=run_result)
|
|
|
-
|
|
|
- if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
|
- if node_instance.should_continue_on_error:
|
|
|
- # if run failed, handle error
|
|
|
- run_result = self._handle_continue_on_error(
|
|
|
- node_instance,
|
|
|
- item.run_result,
|
|
|
- self.graph_runtime_state.variable_pool,
|
|
|
- handle_exceptions=handle_exceptions,
|
|
|
- )
|
|
|
- route_node_state.node_run_result = run_result
|
|
|
- route_node_state.status = RouteNodeState.Status.EXCEPTION
|
|
|
+ # append node output variables to variable pool
|
|
|
if run_result.outputs:
|
|
|
for variable_key, variable_value in run_result.outputs.items():
|
|
|
# append variables to variable pool recursively
|
|
@@ -645,21 +732,23 @@ class GraphEngine:
|
|
|
variable_key_list=[variable_key],
|
|
|
variable_value=variable_value,
|
|
|
)
|
|
|
- yield NodeRunExceptionEvent(
|
|
|
- error=run_result.error or "System Error",
|
|
|
- id=node_instance.id,
|
|
|
- node_id=node_instance.node_id,
|
|
|
- node_type=node_instance.node_type,
|
|
|
- node_data=node_instance.node_data,
|
|
|
- route_node_state=route_node_state,
|
|
|
- parallel_id=parallel_id,
|
|
|
- parallel_start_node_id=parallel_start_node_id,
|
|
|
- parent_parallel_id=parent_parallel_id,
|
|
|
- parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
- )
|
|
|
- else:
|
|
|
- yield NodeRunFailedEvent(
|
|
|
- error=route_node_state.failed_reason or "Unknown error.",
|
|
|
+
|
|
|
+ # add parallel info to run result metadata
|
|
|
+ if parallel_id and parallel_start_node_id:
|
|
|
+ if not run_result.metadata:
|
|
|
+ run_result.metadata = {}
|
|
|
+
|
|
|
+ run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
|
|
+ run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
|
|
|
+ parallel_start_node_id
|
|
|
+ )
|
|
|
+ if parent_parallel_id and parent_parallel_start_node_id:
|
|
|
+ run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
|
|
+ run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
|
|
+ parent_parallel_start_node_id
|
|
|
+ )
|
|
|
+
|
|
|
+ yield NodeRunSucceededEvent(
|
|
|
id=node_instance.id,
|
|
|
node_id=node_instance.node_id,
|
|
|
node_type=node_instance.node_type,
|
|
@@ -670,108 +759,59 @@ class GraphEngine:
|
|
|
parent_parallel_id=parent_parallel_id,
|
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
)
|
|
|
+ shoudl_continue_retry = False
|
|
|
|
|
|
- elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
|
|
- if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
|
|
- node_instance.node_id
|
|
|
- ):
|
|
|
- run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
|
|
- if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
|
|
- # plus state total_tokens
|
|
|
- self.graph_runtime_state.total_tokens += int(
|
|
|
- run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
|
|
- )
|
|
|
-
|
|
|
- if run_result.llm_usage:
|
|
|
- # use the latest usage
|
|
|
- self.graph_runtime_state.llm_usage += run_result.llm_usage
|
|
|
-
|
|
|
- # append node output variables to variable pool
|
|
|
- if run_result.outputs:
|
|
|
- for variable_key, variable_value in run_result.outputs.items():
|
|
|
- # append variables to variable pool recursively
|
|
|
- self._append_variables_recursively(
|
|
|
- node_id=node_instance.node_id,
|
|
|
- variable_key_list=[variable_key],
|
|
|
- variable_value=variable_value,
|
|
|
- )
|
|
|
-
|
|
|
- # add parallel info to run result metadata
|
|
|
- if parallel_id and parallel_start_node_id:
|
|
|
- if not run_result.metadata:
|
|
|
- run_result.metadata = {}
|
|
|
-
|
|
|
- run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
|
|
- run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
|
|
- if parent_parallel_id and parent_parallel_start_node_id:
|
|
|
- run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
|
|
- run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
|
|
- parent_parallel_start_node_id
|
|
|
- )
|
|
|
-
|
|
|
- yield NodeRunSucceededEvent(
|
|
|
+ break
|
|
|
+ elif isinstance(item, RunStreamChunkEvent):
|
|
|
+ yield NodeRunStreamChunkEvent(
|
|
|
id=node_instance.id,
|
|
|
node_id=node_instance.node_id,
|
|
|
node_type=node_instance.node_type,
|
|
|
node_data=node_instance.node_data,
|
|
|
+ chunk_content=item.chunk_content,
|
|
|
+ from_variable_selector=item.from_variable_selector,
|
|
|
route_node_state=route_node_state,
|
|
|
parallel_id=parallel_id,
|
|
|
parallel_start_node_id=parallel_start_node_id,
|
|
|
parent_parallel_id=parent_parallel_id,
|
|
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
)
|
|
|
-
|
|
|
- break
|
|
|
- elif isinstance(item, RunStreamChunkEvent):
|
|
|
- yield NodeRunStreamChunkEvent(
|
|
|
- id=node_instance.id,
|
|
|
- node_id=node_instance.node_id,
|
|
|
- node_type=node_instance.node_type,
|
|
|
- node_data=node_instance.node_data,
|
|
|
- chunk_content=item.chunk_content,
|
|
|
- from_variable_selector=item.from_variable_selector,
|
|
|
- route_node_state=route_node_state,
|
|
|
- parallel_id=parallel_id,
|
|
|
- parallel_start_node_id=parallel_start_node_id,
|
|
|
- parent_parallel_id=parent_parallel_id,
|
|
|
- parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
- )
|
|
|
- elif isinstance(item, RunRetrieverResourceEvent):
|
|
|
- yield NodeRunRetrieverResourceEvent(
|
|
|
- id=node_instance.id,
|
|
|
- node_id=node_instance.node_id,
|
|
|
- node_type=node_instance.node_type,
|
|
|
- node_data=node_instance.node_data,
|
|
|
- retriever_resources=item.retriever_resources,
|
|
|
- context=item.context,
|
|
|
- route_node_state=route_node_state,
|
|
|
- parallel_id=parallel_id,
|
|
|
- parallel_start_node_id=parallel_start_node_id,
|
|
|
- parent_parallel_id=parent_parallel_id,
|
|
|
- parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
- )
|
|
|
- except GenerateTaskStoppedError:
|
|
|
- # trigger node run failed event
|
|
|
- route_node_state.status = RouteNodeState.Status.FAILED
|
|
|
- route_node_state.failed_reason = "Workflow stopped."
|
|
|
- yield NodeRunFailedEvent(
|
|
|
- error="Workflow stopped.",
|
|
|
- id=node_instance.id,
|
|
|
- node_id=node_instance.node_id,
|
|
|
- node_type=node_instance.node_type,
|
|
|
- node_data=node_instance.node_data,
|
|
|
- route_node_state=route_node_state,
|
|
|
- parallel_id=parallel_id,
|
|
|
- parallel_start_node_id=parallel_start_node_id,
|
|
|
- parent_parallel_id=parent_parallel_id,
|
|
|
- parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
- )
|
|
|
- return
|
|
|
- except Exception as e:
|
|
|
- logger.exception(f"Node {node_instance.node_data.title} run failed")
|
|
|
- raise e
|
|
|
- finally:
|
|
|
- db.session.close()
|
|
|
+ elif isinstance(item, RunRetrieverResourceEvent):
|
|
|
+ yield NodeRunRetrieverResourceEvent(
|
|
|
+ id=node_instance.id,
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ node_type=node_instance.node_type,
|
|
|
+ node_data=node_instance.node_data,
|
|
|
+ retriever_resources=item.retriever_resources,
|
|
|
+ context=item.context,
|
|
|
+ route_node_state=route_node_state,
|
|
|
+ parallel_id=parallel_id,
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
+ )
|
|
|
+ except GenerateTaskStoppedError:
|
|
|
+ # trigger node run failed event
|
|
|
+ route_node_state.status = RouteNodeState.Status.FAILED
|
|
|
+ route_node_state.failed_reason = "Workflow stopped."
|
|
|
+ yield NodeRunFailedEvent(
|
|
|
+ error="Workflow stopped.",
|
|
|
+ id=node_instance.id,
|
|
|
+ node_id=node_instance.node_id,
|
|
|
+ node_type=node_instance.node_type,
|
|
|
+ node_data=node_instance.node_data,
|
|
|
+ route_node_state=route_node_state,
|
|
|
+ parallel_id=parallel_id,
|
|
|
+ parallel_start_node_id=parallel_start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel_id,
|
|
|
+ parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
|
+ )
|
|
|
+ return
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception(f"Node {node_instance.node_data.title} run failed")
|
|
|
+ raise e
|
|
|
+ finally:
|
|
|
+ db.session.close()
|
|
|
|
|
|
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
|
|
"""
|