|
@@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
|
|
|
NodeRunSucceededEvent,
|
|
|
)
|
|
|
from core.workflow.graph_engine.entities.graph import Graph
|
|
|
-from core.workflow.graph_engine.entities.run_condition import RunCondition
|
|
|
from core.workflow.nodes.base_node import BaseNode
|
|
|
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
|
|
from core.workflow.nodes.iteration.entities import IterationNodeData
|
|
|
-from core.workflow.utils.condition.entities import Condition
|
|
|
from models.workflow import WorkflowNodeExecutionStatus
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -68,38 +66,6 @@ class IterationNode(BaseNode):
|
|
|
if not iteration_graph:
|
|
|
raise ValueError("iteration graph not found")
|
|
|
|
|
|
- leaf_node_ids = iteration_graph.get_leaf_node_ids()
|
|
|
- iteration_leaf_node_ids = []
|
|
|
- for leaf_node_id in leaf_node_ids:
|
|
|
- node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
|
|
|
- if not node_config:
|
|
|
- continue
|
|
|
-
|
|
|
- leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
|
|
|
- if not leaf_node_iteration_id:
|
|
|
- continue
|
|
|
-
|
|
|
- if leaf_node_iteration_id != self.node_id:
|
|
|
- continue
|
|
|
-
|
|
|
- iteration_leaf_node_ids.append(leaf_node_id)
|
|
|
-
|
|
|
- # add condition of end nodes to root node
|
|
|
- iteration_graph.add_extra_edge(
|
|
|
- source_node_id=leaf_node_id,
|
|
|
- target_node_id=root_node_id,
|
|
|
- run_condition=RunCondition(
|
|
|
- type="condition",
|
|
|
- conditions=[
|
|
|
- Condition(
|
|
|
- variable_selector=[self.node_id, "index"],
|
|
|
- comparison_operator="<",
|
|
|
- value=str(len(iterator_list_value)),
|
|
|
- )
|
|
|
- ],
|
|
|
- ),
|
|
|
- )
|
|
|
-
|
|
|
variable_pool = self.graph_runtime_state.variable_pool
|
|
|
|
|
|
# append iteration variable (item, index) to variable pool
|
|
@@ -149,91 +115,90 @@ class IterationNode(BaseNode):
|
|
|
|
|
|
outputs: list[Any] = []
|
|
|
try:
|
|
|
- # run workflow
|
|
|
- rst = graph_engine.run()
|
|
|
- for event in rst:
|
|
|
- if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
|
|
- event.in_iteration_id = self.node_id
|
|
|
-
|
|
|
- if (
|
|
|
- isinstance(event, BaseNodeEvent)
|
|
|
- and event.node_type == NodeType.ITERATION_START
|
|
|
- and not isinstance(event, NodeRunStreamChunkEvent)
|
|
|
- ):
|
|
|
- continue
|
|
|
-
|
|
|
- if isinstance(event, NodeRunSucceededEvent):
|
|
|
- if event.route_node_state.node_run_result:
|
|
|
- metadata = event.route_node_state.node_run_result.metadata
|
|
|
- if not metadata:
|
|
|
- metadata = {}
|
|
|
-
|
|
|
- if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
|
|
- metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
|
|
- metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
|
|
|
- [self.node_id, "index"]
|
|
|
- )
|
|
|
- event.route_node_state.node_run_result.metadata = metadata
|
|
|
-
|
|
|
- yield event
|
|
|
-
|
|
|
- # handle iteration run result
|
|
|
- if event.route_node_state.node_id in iteration_leaf_node_ids:
|
|
|
- # append to iteration output variable list
|
|
|
- current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
|
|
|
- outputs.append(current_iteration_output)
|
|
|
-
|
|
|
- # remove all nodes outputs from variable pool
|
|
|
- for node_id in iteration_graph.node_ids:
|
|
|
- variable_pool.remove_node(node_id)
|
|
|
-
|
|
|
- # move to next iteration
|
|
|
- current_index = variable_pool.get([self.node_id, "index"])
|
|
|
- if current_index is None:
|
|
|
- raise ValueError(f"iteration {self.node_id} current index not found")
|
|
|
-
|
|
|
- next_index = int(current_index.to_object()) + 1
|
|
|
- variable_pool.add([self.node_id, "index"], next_index)
|
|
|
-
|
|
|
- if next_index < len(iterator_list_value):
|
|
|
- variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
|
|
-
|
|
|
- yield IterationRunNextEvent(
|
|
|
- iteration_id=self.id,
|
|
|
- iteration_node_id=self.node_id,
|
|
|
- iteration_node_type=self.node_type,
|
|
|
- iteration_node_data=self.node_data,
|
|
|
- index=next_index,
|
|
|
- pre_iteration_output=jsonable_encoder(current_iteration_output)
|
|
|
- if current_iteration_output
|
|
|
- else None,
|
|
|
- )
|
|
|
- elif isinstance(event, BaseGraphEvent):
|
|
|
- if isinstance(event, GraphRunFailedEvent):
|
|
|
- # iteration run failed
|
|
|
- yield IterationRunFailedEvent(
|
|
|
- iteration_id=self.id,
|
|
|
- iteration_node_id=self.node_id,
|
|
|
- iteration_node_type=self.node_type,
|
|
|
- iteration_node_data=self.node_data,
|
|
|
- start_at=start_at,
|
|
|
- inputs=inputs,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
- steps=len(iterator_list_value),
|
|
|
- metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
- error=event.error,
|
|
|
- )
|
|
|
-
|
|
|
- yield RunCompletedEvent(
|
|
|
- run_result=NodeRunResult(
|
|
|
- status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
+ for _ in range(len(iterator_list_value)):
|
|
|
+ # run workflow
|
|
|
+ rst = graph_engine.run()
|
|
|
+ for event in rst:
|
|
|
+ if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
|
|
+ event.in_iteration_id = self.node_id
|
|
|
+
|
|
|
+ if (
|
|
|
+ isinstance(event, BaseNodeEvent)
|
|
|
+ and event.node_type == NodeType.ITERATION_START
|
|
|
+ and not isinstance(event, NodeRunStreamChunkEvent)
|
|
|
+ ):
|
|
|
+ continue
|
|
|
+
|
|
|
+ if isinstance(event, NodeRunSucceededEvent):
|
|
|
+ if event.route_node_state.node_run_result:
|
|
|
+ metadata = event.route_node_state.node_run_result.metadata
|
|
|
+ if not metadata:
|
|
|
+ metadata = {}
|
|
|
+
|
|
|
+ if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
|
|
+ metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
|
|
+ metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
|
|
|
+ [self.node_id, "index"]
|
|
|
+ )
|
|
|
+ event.route_node_state.node_run_result.metadata = metadata
|
|
|
+
|
|
|
+ yield event
|
|
|
+ elif isinstance(event, BaseGraphEvent):
|
|
|
+ if isinstance(event, GraphRunFailedEvent):
|
|
|
+ # iteration run failed
|
|
|
+ yield IterationRunFailedEvent(
|
|
|
+ iteration_id=self.id,
|
|
|
+ iteration_node_id=self.node_id,
|
|
|
+ iteration_node_type=self.node_type,
|
|
|
+ iteration_node_data=self.node_data,
|
|
|
+ start_at=start_at,
|
|
|
+ inputs=inputs,
|
|
|
+ outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ steps=len(iterator_list_value),
|
|
|
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
error=event.error,
|
|
|
)
|
|
|
- )
|
|
|
- break
|
|
|
- else:
|
|
|
- event = cast(InNodeEvent, event)
|
|
|
- yield event
|
|
|
+
|
|
|
+ yield RunCompletedEvent(
|
|
|
+ run_result=NodeRunResult(
|
|
|
+ status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
+ error=event.error,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ return
|
|
|
+ else:
|
|
|
+ event = cast(InNodeEvent, event)
|
|
|
+ yield event
|
|
|
+
|
|
|
+ # append to iteration output variable list
|
|
|
+ current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
|
|
|
+ outputs.append(current_iteration_output)
|
|
|
+
|
|
|
+ # remove all nodes outputs from variable pool
|
|
|
+ for node_id in iteration_graph.node_ids:
|
|
|
+ variable_pool.remove_node(node_id)
|
|
|
+
|
|
|
+ # move to next iteration
|
|
|
+ current_index = variable_pool.get([self.node_id, "index"])
|
|
|
+ if current_index is None:
|
|
|
+ raise ValueError(f"iteration {self.node_id} current index not found")
|
|
|
+
|
|
|
+ next_index = int(current_index.to_object()) + 1
|
|
|
+ variable_pool.add([self.node_id, "index"], next_index)
|
|
|
+
|
|
|
+ if next_index < len(iterator_list_value):
|
|
|
+ variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
|
|
+
|
|
|
+ yield IterationRunNextEvent(
|
|
|
+ iteration_id=self.id,
|
|
|
+ iteration_node_id=self.node_id,
|
|
|
+ iteration_node_type=self.node_type,
|
|
|
+ iteration_node_data=self.node_data,
|
|
|
+ index=next_index,
|
|
|
+ pre_iteration_output=jsonable_encoder(current_iteration_output)
|
|
|
+ if current_iteration_output
|
|
|
+ else None,
|
|
|
+ )
|
|
|
|
|
|
yield IterationRunSucceededEvent(
|
|
|
iteration_id=self.id,
|