فهرست منبع

chore(workflow): Optimize the iteration when selecting a variable from a branch in the output variable causes iteration index err (#8440)

takatost 7 ماه پیش
والد
کامیت
88c9834ef2
2فایلهای تغییر یافته به همراه84 افزوده شده و 131 حذف شده
  1. 2 14
      api/core/workflow/graph_engine/entities/graph.py
  2. 82 117
      api/core/workflow/nodes/iteration/iteration_node.py

+ 2 - 14
api/core/workflow/graph_engine/entities/graph.py

@@ -689,23 +689,11 @@ class Graph(BaseModel):
 
                     parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
 
-        parallel_start_node_id = None
-        for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
+        for _, branch_node_ids in parallel_start_node_ids.items():
             if set(branch_node_ids) == set(routes_node_ids.keys()):
-                parallel_start_node_id = p_start_node_id
                 return True
 
-        if not parallel_start_node_id:
-            raise Exception("Parallel start node id not found")
-
-        for graph_edge in reverse_edge_mapping[start_node_id]:
-            if (
-                graph_edge.source_node_id not in all_routes_node_ids
-                or graph_edge.source_node_id != parallel_start_node_id
-            ):
-                return False
-
-        return True
+        return False
 
     @classmethod
     def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:

+ 82 - 117
api/core/workflow/nodes/iteration/iteration_node.py

@@ -20,11 +20,9 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 from core.workflow.nodes.iteration.entities import IterationNodeData
-from core.workflow.utils.condition.entities import Condition
 from models.workflow import WorkflowNodeExecutionStatus
 
 logger = logging.getLogger(__name__)
@@ -68,38 +66,6 @@ class IterationNode(BaseNode):
         if not iteration_graph:
             raise ValueError("iteration graph not found")
 
-        leaf_node_ids = iteration_graph.get_leaf_node_ids()
-        iteration_leaf_node_ids = []
-        for leaf_node_id in leaf_node_ids:
-            node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
-            if not node_config:
-                continue
-
-            leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
-            if not leaf_node_iteration_id:
-                continue
-
-            if leaf_node_iteration_id != self.node_id:
-                continue
-
-            iteration_leaf_node_ids.append(leaf_node_id)
-
-            # add condition of end nodes to root node
-            iteration_graph.add_extra_edge(
-                source_node_id=leaf_node_id,
-                target_node_id=root_node_id,
-                run_condition=RunCondition(
-                    type="condition",
-                    conditions=[
-                        Condition(
-                            variable_selector=[self.node_id, "index"],
-                            comparison_operator="<",
-                            value=str(len(iterator_list_value)),
-                        )
-                    ],
-                ),
-            )
-
         variable_pool = self.graph_runtime_state.variable_pool
 
         # append iteration variable (item, index) to variable pool
@@ -149,91 +115,90 @@ class IterationNode(BaseNode):
 
         outputs: list[Any] = []
         try:
-            # run workflow
-            rst = graph_engine.run()
-            for event in rst:
-                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
-                    event.in_iteration_id = self.node_id
-
-                if (
-                    isinstance(event, BaseNodeEvent)
-                    and event.node_type == NodeType.ITERATION_START
-                    and not isinstance(event, NodeRunStreamChunkEvent)
-                ):
-                    continue
-
-                if isinstance(event, NodeRunSucceededEvent):
-                    if event.route_node_state.node_run_result:
-                        metadata = event.route_node_state.node_run_result.metadata
-                        if not metadata:
-                            metadata = {}
-
-                        if NodeRunMetadataKey.ITERATION_ID not in metadata:
-                            metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
-                            metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
-                                [self.node_id, "index"]
-                            )
-                            event.route_node_state.node_run_result.metadata = metadata
-
-                    yield event
-
-                    # handle iteration run result
-                    if event.route_node_state.node_id in iteration_leaf_node_ids:
-                        # append to iteration output variable list
-                        current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
-                        outputs.append(current_iteration_output)
-
-                        # remove all nodes outputs from variable pool
-                        for node_id in iteration_graph.node_ids:
-                            variable_pool.remove_node(node_id)
-
-                        # move to next iteration
-                        current_index = variable_pool.get([self.node_id, "index"])
-                        if current_index is None:
-                            raise ValueError(f"iteration {self.node_id} current index not found")
-
-                        next_index = int(current_index.to_object()) + 1
-                        variable_pool.add([self.node_id, "index"], next_index)
-
-                        if next_index < len(iterator_list_value):
-                            variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
-
-                        yield IterationRunNextEvent(
-                            iteration_id=self.id,
-                            iteration_node_id=self.node_id,
-                            iteration_node_type=self.node_type,
-                            iteration_node_data=self.node_data,
-                            index=next_index,
-                            pre_iteration_output=jsonable_encoder(current_iteration_output)
-                            if current_iteration_output
-                            else None,
-                        )
-                elif isinstance(event, BaseGraphEvent):
-                    if isinstance(event, GraphRunFailedEvent):
-                        # iteration run failed
-                        yield IterationRunFailedEvent(
-                            iteration_id=self.id,
-                            iteration_node_id=self.node_id,
-                            iteration_node_type=self.node_type,
-                            iteration_node_data=self.node_data,
-                            start_at=start_at,
-                            inputs=inputs,
-                            outputs={"output": jsonable_encoder(outputs)},
-                            steps=len(iterator_list_value),
-                            metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
-                            error=event.error,
-                        )
-
-                        yield RunCompletedEvent(
-                            run_result=NodeRunResult(
-                                status=WorkflowNodeExecutionStatus.FAILED,
+            for _ in range(len(iterator_list_value)):
+                # run workflow
+                rst = graph_engine.run()
+                for event in rst:
+                    if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                        event.in_iteration_id = self.node_id
+
+                    if (
+                        isinstance(event, BaseNodeEvent)
+                        and event.node_type == NodeType.ITERATION_START
+                        and not isinstance(event, NodeRunStreamChunkEvent)
+                    ):
+                        continue
+
+                    if isinstance(event, NodeRunSucceededEvent):
+                        if event.route_node_state.node_run_result:
+                            metadata = event.route_node_state.node_run_result.metadata
+                            if not metadata:
+                                metadata = {}
+
+                            if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                                metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                                metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
+                                    [self.node_id, "index"]
+                                )
+                                event.route_node_state.node_run_result.metadata = metadata
+
+                        yield event
+                    elif isinstance(event, BaseGraphEvent):
+                        if isinstance(event, GraphRunFailedEvent):
+                            # iteration run failed
+                            yield IterationRunFailedEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                start_at=start_at,
+                                inputs=inputs,
+                                outputs={"output": jsonable_encoder(outputs)},
+                                steps=len(iterator_list_value),
+                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                                 error=event.error,
                             )
-                        )
-                        break
-                else:
-                    event = cast(InNodeEvent, event)
-                    yield event
+
+                            yield RunCompletedEvent(
+                                run_result=NodeRunResult(
+                                    status=WorkflowNodeExecutionStatus.FAILED,
+                                    error=event.error,
+                                )
+                            )
+                            return
+                    else:
+                        event = cast(InNodeEvent, event)
+                        yield event
+
+                # append to iteration output variable list
+                current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
+                outputs.append(current_iteration_output)
+
+                # remove all nodes outputs from variable pool
+                for node_id in iteration_graph.node_ids:
+                    variable_pool.remove_node(node_id)
+
+                # move to next iteration
+                current_index = variable_pool.get([self.node_id, "index"])
+                if current_index is None:
+                    raise ValueError(f"iteration {self.node_id} current index not found")
+
+                next_index = int(current_index.to_object()) + 1
+                variable_pool.add([self.node_id, "index"], next_index)
+
+                if next_index < len(iterator_list_value):
+                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
+
+                yield IterationRunNextEvent(
+                    iteration_id=self.id,
+                    iteration_node_id=self.node_id,
+                    iteration_node_type=self.node_type,
+                    iteration_node_data=self.node_data,
+                    index=next_index,
+                    pre_iteration_output=jsonable_encoder(current_iteration_output)
+                    if current_iteration_output
+                    else None,
+                )
 
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,