Przeglądaj źródła

fix(workflow): in multi-parallel execution with multiple conditional branches (#8221)

takatost 7 miesięcy temu
rodzic
commit
f6dfe23cf8
1 zmienionych plików z 167 dodań i 102 usunięć
  1. 167 102
      api/core/workflow/graph_engine/entities/graph.py

+ 167 - 102
api/core/workflow/graph_engine/entities/graph.py

@@ -304,11 +304,14 @@ class Graph(BaseModel):
         parallel = None
         if len(target_node_edges) > 1:
             # fetch all node ids in current parallels
-            parallel_branch_node_ids = []
+            parallel_branch_node_ids = {}
             condition_edge_mappings = {}
             for graph_edge in target_node_edges:
                 if graph_edge.run_condition is None:
-                    parallel_branch_node_ids.append(graph_edge.target_node_id)
+                    if "default" not in parallel_branch_node_ids:
+                        parallel_branch_node_ids["default"] = []
+
+                    parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
                 else:
                     condition_hash = graph_edge.run_condition.hash
                     if not condition_hash in condition_edge_mappings:
@@ -316,120 +319,182 @@ class Graph(BaseModel):
 
                     condition_edge_mappings[condition_hash].append(graph_edge)
 
-            for _, graph_edges in condition_edge_mappings.items():
+            for condition_hash, graph_edges in condition_edge_mappings.items():
                 if len(graph_edges) > 1:
+                    if condition_hash not in parallel_branch_node_ids:
+                        parallel_branch_node_ids[condition_hash] = []
+
                     for graph_edge in graph_edges:
-                        parallel_branch_node_ids.append(graph_edge.target_node_id)
+                        parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
+
+            condition_parallels = {}
+            for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
+                # any target node id in node_parallel_mapping
+                parallel = None
+                if condition_parallel_branch_node_ids:
+                    parent_parallel_id = parent_parallel.id if parent_parallel else None
+
+                    parallel = GraphParallel(
+                        start_from_node_id=start_node_id,
+                        parent_parallel_id=parent_parallel.id if parent_parallel else None,
+                        parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
+                    )
+                    parallel_mapping[parallel.id] = parallel
+                    condition_parallels[condition_hash] = parallel
+
+                    in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
+                        edge_mapping=edge_mapping,
+                        reverse_edge_mapping=reverse_edge_mapping,
+                        parallel_branch_node_ids=condition_parallel_branch_node_ids,
+                    )
+
+                    # collect all branches node ids
+                    parallel_node_ids = []
+                    for _, node_ids in in_branch_node_ids.items():
+                        for node_id in node_ids:
+                            in_parent_parallel = True
+                            if parent_parallel_id:
+                                in_parent_parallel = False
+                                for parallel_node_id, parallel_id in node_parallel_mapping.items():
+                                    if parallel_id == parent_parallel_id and parallel_node_id == node_id:
+                                        in_parent_parallel = True
+                                        break
+
+                            if in_parent_parallel:
+                                parallel_node_ids.append(node_id)
+                                node_parallel_mapping[node_id] = parallel.id
+
+                    outside_parallel_target_node_ids = set()
+                    for node_id in parallel_node_ids:
+                        if node_id == parallel.start_from_node_id:
+                            continue
+
+                        node_edges = edge_mapping.get(node_id)
+                        if not node_edges:
+                            continue
+
+                        if len(node_edges) > 1:
+                            continue
 
-            # any target node id in node_parallel_mapping
-            if parallel_branch_node_ids:
-                parent_parallel_id = parent_parallel.id if parent_parallel else None
+                        target_node_id = node_edges[0].target_node_id
+                        if target_node_id in parallel_node_ids:
+                            continue
 
-                parallel = GraphParallel(
-                    start_from_node_id=start_node_id,
-                    parent_parallel_id=parent_parallel.id if parent_parallel else None,
-                    parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
+                        if parent_parallel_id:
+                            parent_parallel = parallel_mapping.get(parent_parallel_id)
+                            if not parent_parallel:
+                                continue
+
+                        if (
+                            (
+                                node_parallel_mapping.get(target_node_id)
+                                and node_parallel_mapping.get(target_node_id) == parent_parallel_id
+                            )
+                            or (
+                                parent_parallel
+                                and parent_parallel.end_to_node_id
+                                and target_node_id == parent_parallel.end_to_node_id
+                            )
+                            or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
+                        ):
+                            outside_parallel_target_node_ids.add(target_node_id)
+
+                    if len(outside_parallel_target_node_ids) == 1:
+                        if (
+                            parent_parallel
+                            and parent_parallel.end_to_node_id
+                            and parallel.end_to_node_id == parent_parallel.end_to_node_id
+                        ):
+                            parallel.end_to_node_id = None
+                        else:
+                            parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
+
+            if condition_edge_mappings:
+                for condition_hash, graph_edges in condition_edge_mappings.items():
+                    current_parallel = cls._get_current_parallel(
+                        parallel_mapping=parallel_mapping,
+                        graph_edge=graph_edge,
+                        parallel=condition_parallels.get(condition_hash),
+                        parent_parallel=parent_parallel,
+                    )
+
+                    cls._recursively_add_parallels(
+                        edge_mapping=edge_mapping,
+                        reverse_edge_mapping=reverse_edge_mapping,
+                        start_node_id=graph_edge.target_node_id,
+                        parallel_mapping=parallel_mapping,
+                        node_parallel_mapping=node_parallel_mapping,
+                        parent_parallel=current_parallel,
+                    )
+            else:
+                for graph_edge in target_node_edges:
+                    current_parallel = cls._get_current_parallel(
+                        parallel_mapping=parallel_mapping,
+                        graph_edge=graph_edge,
+                        parallel=parallel,
+                        parent_parallel=parent_parallel,
+                    )
+
+                    cls._recursively_add_parallels(
+                        edge_mapping=edge_mapping,
+                        reverse_edge_mapping=reverse_edge_mapping,
+                        start_node_id=graph_edge.target_node_id,
+                        parallel_mapping=parallel_mapping,
+                        node_parallel_mapping=node_parallel_mapping,
+                        parent_parallel=current_parallel,
+                    )
+        else:
+            for graph_edge in target_node_edges:
+                current_parallel = cls._get_current_parallel(
+                    parallel_mapping=parallel_mapping,
+                    graph_edge=graph_edge,
+                    parallel=parallel,
+                    parent_parallel=parent_parallel,
                 )
-                parallel_mapping[parallel.id] = parallel
 
-                in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
+                cls._recursively_add_parallels(
                     edge_mapping=edge_mapping,
                     reverse_edge_mapping=reverse_edge_mapping,
-                    parallel_branch_node_ids=parallel_branch_node_ids,
+                    start_node_id=graph_edge.target_node_id,
+                    parallel_mapping=parallel_mapping,
+                    node_parallel_mapping=node_parallel_mapping,
+                    parent_parallel=current_parallel,
                 )
 
-                # collect all branches node ids
-                parallel_node_ids = []
-                for _, node_ids in in_branch_node_ids.items():
-                    for node_id in node_ids:
-                        in_parent_parallel = True
-                        if parent_parallel_id:
-                            in_parent_parallel = False
-                            for parallel_node_id, parallel_id in node_parallel_mapping.items():
-                                if parallel_id == parent_parallel_id and parallel_node_id == node_id:
-                                    in_parent_parallel = True
-                                    break
-
-                        if in_parent_parallel:
-                            parallel_node_ids.append(node_id)
-                            node_parallel_mapping[node_id] = parallel.id
-
-                outside_parallel_target_node_ids = set()
-                for node_id in parallel_node_ids:
-                    if node_id == parallel.start_from_node_id:
-                        continue
-
-                    node_edges = edge_mapping.get(node_id)
-                    if not node_edges:
-                        continue
-
-                    if len(node_edges) > 1:
-                        continue
-
-                    target_node_id = node_edges[0].target_node_id
-                    if target_node_id in parallel_node_ids:
-                        continue
-
-                    if parent_parallel_id:
-                        parent_parallel = parallel_mapping.get(parent_parallel_id)
-                        if not parent_parallel:
-                            continue
-
-                    if (
-                        (
-                            node_parallel_mapping.get(target_node_id)
-                            and node_parallel_mapping.get(target_node_id) == parent_parallel_id
-                        )
+    @classmethod
+    def _get_current_parallel(
+        cls,
+        parallel_mapping: dict[str, GraphParallel],
+        graph_edge: GraphEdge,
+        parallel: Optional[GraphParallel] = None,
+        parent_parallel: Optional[GraphParallel] = None,
+    ) -> Optional[GraphParallel]:
+        """
+        Get current parallel
+        """
+        current_parallel = None
+        if parallel:
+            current_parallel = parallel
+        elif parent_parallel:
+            if not parent_parallel.end_to_node_id or (
+                parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
+            ):
+                current_parallel = parent_parallel
+            else:
+                # fetch parent parallel's parent parallel
+                parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
+                if parent_parallel_parent_parallel_id:
+                    parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
+                    if parent_parallel_parent_parallel and (
+                        not parent_parallel_parent_parallel.end_to_node_id
                         or (
-                            parent_parallel
-                            and parent_parallel.end_to_node_id
-                            and target_node_id == parent_parallel.end_to_node_id
+                            parent_parallel_parent_parallel.end_to_node_id
+                            and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
                         )
-                        or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                     ):
-                        outside_parallel_target_node_ids.add(target_node_id)
+                        current_parallel = parent_parallel_parent_parallel
 
-                if len(outside_parallel_target_node_ids) == 1:
-                    if (
-                        parent_parallel
-                        and parent_parallel.end_to_node_id
-                        and parallel.end_to_node_id == parent_parallel.end_to_node_id
-                    ):
-                        parallel.end_to_node_id = None
-                    else:
-                        parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
-
-        for graph_edge in target_node_edges:
-            current_parallel = None
-            if parallel:
-                current_parallel = parallel
-            elif parent_parallel:
-                if not parent_parallel.end_to_node_id or (
-                    parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
-                ):
-                    current_parallel = parent_parallel
-                else:
-                    # fetch parent parallel's parent parallel
-                    parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
-                    if parent_parallel_parent_parallel_id:
-                        parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
-                        if parent_parallel_parent_parallel and (
-                            not parent_parallel_parent_parallel.end_to_node_id
-                            or (
-                                parent_parallel_parent_parallel.end_to_node_id
-                                and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
-                            )
-                        ):
-                            current_parallel = parent_parallel_parent_parallel
-
-            cls._recursively_add_parallels(
-                edge_mapping=edge_mapping,
-                reverse_edge_mapping=reverse_edge_mapping,
-                start_node_id=graph_edge.target_node_id,
-                parallel_mapping=parallel_mapping,
-                node_parallel_mapping=node_parallel_mapping,
-                parent_parallel=current_parallel,
-            )
+        return current_parallel
 
     @classmethod
     def _check_exceed_parallel_limit(