|
@@ -304,11 +304,14 @@ class Graph(BaseModel):
|
|
|
parallel = None
|
|
|
if len(target_node_edges) > 1:
|
|
|
# fetch all node ids in current parallels
|
|
|
- parallel_branch_node_ids = []
|
|
|
+ parallel_branch_node_ids = {}
|
|
|
condition_edge_mappings = {}
|
|
|
for graph_edge in target_node_edges:
|
|
|
if graph_edge.run_condition is None:
|
|
|
- parallel_branch_node_ids.append(graph_edge.target_node_id)
|
|
|
+ if "default" not in parallel_branch_node_ids:
|
|
|
+ parallel_branch_node_ids["default"] = []
|
|
|
+
|
|
|
+ parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
|
|
|
else:
|
|
|
condition_hash = graph_edge.run_condition.hash
|
|
|
if not condition_hash in condition_edge_mappings:
|
|
@@ -316,120 +319,182 @@ class Graph(BaseModel):
|
|
|
|
|
|
condition_edge_mappings[condition_hash].append(graph_edge)
|
|
|
|
|
|
- for _, graph_edges in condition_edge_mappings.items():
|
|
|
+ for condition_hash, graph_edges in condition_edge_mappings.items():
|
|
|
if len(graph_edges) > 1:
|
|
|
+ if condition_hash not in parallel_branch_node_ids:
|
|
|
+ parallel_branch_node_ids[condition_hash] = []
|
|
|
+
|
|
|
for graph_edge in graph_edges:
|
|
|
- parallel_branch_node_ids.append(graph_edge.target_node_id)
|
|
|
+ parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
|
|
|
+
|
|
|
+ condition_parallels = {}
|
|
|
+ for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
|
|
|
+ # any target node id in node_parallel_mapping
|
|
|
+ parallel = None
|
|
|
+ if condition_parallel_branch_node_ids:
|
|
|
+ parent_parallel_id = parent_parallel.id if parent_parallel else None
|
|
|
+
|
|
|
+ parallel = GraphParallel(
|
|
|
+ start_from_node_id=start_node_id,
|
|
|
+ parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
|
|
+ parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
|
|
+ )
|
|
|
+ parallel_mapping[parallel.id] = parallel
|
|
|
+ condition_parallels[condition_hash] = parallel
|
|
|
+
|
|
|
+ in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
|
|
+ edge_mapping=edge_mapping,
|
|
|
+ reverse_edge_mapping=reverse_edge_mapping,
|
|
|
+ parallel_branch_node_ids=condition_parallel_branch_node_ids,
|
|
|
+ )
|
|
|
+
|
|
|
+ # collect all branches node ids
|
|
|
+ parallel_node_ids = []
|
|
|
+ for _, node_ids in in_branch_node_ids.items():
|
|
|
+ for node_id in node_ids:
|
|
|
+ in_parent_parallel = True
|
|
|
+ if parent_parallel_id:
|
|
|
+ in_parent_parallel = False
|
|
|
+ for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
|
|
+ if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
|
|
+ in_parent_parallel = True
|
|
|
+ break
|
|
|
+
|
|
|
+ if in_parent_parallel:
|
|
|
+ parallel_node_ids.append(node_id)
|
|
|
+ node_parallel_mapping[node_id] = parallel.id
|
|
|
+
|
|
|
+ outside_parallel_target_node_ids = set()
|
|
|
+ for node_id in parallel_node_ids:
|
|
|
+ if node_id == parallel.start_from_node_id:
|
|
|
+ continue
|
|
|
+
|
|
|
+ node_edges = edge_mapping.get(node_id)
|
|
|
+ if not node_edges:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if len(node_edges) > 1:
|
|
|
+ continue
|
|
|
|
|
|
- # any target node id in node_parallel_mapping
|
|
|
- if parallel_branch_node_ids:
|
|
|
- parent_parallel_id = parent_parallel.id if parent_parallel else None
|
|
|
+ target_node_id = node_edges[0].target_node_id
|
|
|
+ if target_node_id in parallel_node_ids:
|
|
|
+ continue
|
|
|
|
|
|
- parallel = GraphParallel(
|
|
|
- start_from_node_id=start_node_id,
|
|
|
- parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
|
|
- parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
|
|
+ if parent_parallel_id:
|
|
|
+ parent_parallel = parallel_mapping.get(parent_parallel_id)
|
|
|
+ if not parent_parallel:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if (
|
|
|
+ (
|
|
|
+ node_parallel_mapping.get(target_node_id)
|
|
|
+ and node_parallel_mapping.get(target_node_id) == parent_parallel_id
|
|
|
+ )
|
|
|
+ or (
|
|
|
+ parent_parallel
|
|
|
+ and parent_parallel.end_to_node_id
|
|
|
+ and target_node_id == parent_parallel.end_to_node_id
|
|
|
+ )
|
|
|
+ or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
|
|
+ ):
|
|
|
+ outside_parallel_target_node_ids.add(target_node_id)
|
|
|
+
|
|
|
+ if len(outside_parallel_target_node_ids) == 1:
|
|
|
+ if (
|
|
|
+ parent_parallel
|
|
|
+ and parent_parallel.end_to_node_id
|
|
|
+ and parallel.end_to_node_id == parent_parallel.end_to_node_id
|
|
|
+ ):
|
|
|
+ parallel.end_to_node_id = None
|
|
|
+ else:
|
|
|
+ parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
|
|
+
|
|
|
+ if condition_edge_mappings:
|
|
|
+ for condition_hash, graph_edges in condition_edge_mappings.items():
|
|
|
+ current_parallel = cls._get_current_parallel(
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ graph_edge=graph_edge,
|
|
|
+ parallel=condition_parallels.get(condition_hash),
|
|
|
+ parent_parallel=parent_parallel,
|
|
|
+ )
|
|
|
+
|
|
|
+ cls._recursively_add_parallels(
|
|
|
+ edge_mapping=edge_mapping,
|
|
|
+ reverse_edge_mapping=reverse_edge_mapping,
|
|
|
+ start_node_id=graph_edge.target_node_id,
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ node_parallel_mapping=node_parallel_mapping,
|
|
|
+ parent_parallel=current_parallel,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ for graph_edge in target_node_edges:
|
|
|
+ current_parallel = cls._get_current_parallel(
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ graph_edge=graph_edge,
|
|
|
+ parallel=parallel,
|
|
|
+ parent_parallel=parent_parallel,
|
|
|
+ )
|
|
|
+
|
|
|
+ cls._recursively_add_parallels(
|
|
|
+ edge_mapping=edge_mapping,
|
|
|
+ reverse_edge_mapping=reverse_edge_mapping,
|
|
|
+ start_node_id=graph_edge.target_node_id,
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ node_parallel_mapping=node_parallel_mapping,
|
|
|
+ parent_parallel=current_parallel,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ for graph_edge in target_node_edges:
|
|
|
+ current_parallel = cls._get_current_parallel(
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ graph_edge=graph_edge,
|
|
|
+ parallel=parallel,
|
|
|
+ parent_parallel=parent_parallel,
|
|
|
)
|
|
|
- parallel_mapping[parallel.id] = parallel
|
|
|
|
|
|
- in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
|
|
+ cls._recursively_add_parallels(
|
|
|
edge_mapping=edge_mapping,
|
|
|
reverse_edge_mapping=reverse_edge_mapping,
|
|
|
- parallel_branch_node_ids=parallel_branch_node_ids,
|
|
|
+ start_node_id=graph_edge.target_node_id,
|
|
|
+ parallel_mapping=parallel_mapping,
|
|
|
+ node_parallel_mapping=node_parallel_mapping,
|
|
|
+ parent_parallel=current_parallel,
|
|
|
)
|
|
|
|
|
|
- # collect all branches node ids
|
|
|
- parallel_node_ids = []
|
|
|
- for _, node_ids in in_branch_node_ids.items():
|
|
|
- for node_id in node_ids:
|
|
|
- in_parent_parallel = True
|
|
|
- if parent_parallel_id:
|
|
|
- in_parent_parallel = False
|
|
|
- for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
|
|
- if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
|
|
- in_parent_parallel = True
|
|
|
- break
|
|
|
-
|
|
|
- if in_parent_parallel:
|
|
|
- parallel_node_ids.append(node_id)
|
|
|
- node_parallel_mapping[node_id] = parallel.id
|
|
|
-
|
|
|
- outside_parallel_target_node_ids = set()
|
|
|
- for node_id in parallel_node_ids:
|
|
|
- if node_id == parallel.start_from_node_id:
|
|
|
- continue
|
|
|
-
|
|
|
- node_edges = edge_mapping.get(node_id)
|
|
|
- if not node_edges:
|
|
|
- continue
|
|
|
-
|
|
|
- if len(node_edges) > 1:
|
|
|
- continue
|
|
|
-
|
|
|
- target_node_id = node_edges[0].target_node_id
|
|
|
- if target_node_id in parallel_node_ids:
|
|
|
- continue
|
|
|
-
|
|
|
- if parent_parallel_id:
|
|
|
- parent_parallel = parallel_mapping.get(parent_parallel_id)
|
|
|
- if not parent_parallel:
|
|
|
- continue
|
|
|
-
|
|
|
- if (
|
|
|
- (
|
|
|
- node_parallel_mapping.get(target_node_id)
|
|
|
- and node_parallel_mapping.get(target_node_id) == parent_parallel_id
|
|
|
- )
|
|
|
+ @classmethod
|
|
|
+ def _get_current_parallel(
|
|
|
+ cls,
|
|
|
+ parallel_mapping: dict[str, GraphParallel],
|
|
|
+ graph_edge: GraphEdge,
|
|
|
+ parallel: Optional[GraphParallel] = None,
|
|
|
+ parent_parallel: Optional[GraphParallel] = None,
|
|
|
+ ) -> Optional[GraphParallel]:
|
|
|
+ """
|
|
|
+ Get current parallel
|
|
|
+ """
|
|
|
+ current_parallel = None
|
|
|
+ if parallel:
|
|
|
+ current_parallel = parallel
|
|
|
+ elif parent_parallel:
|
|
|
+ if not parent_parallel.end_to_node_id or (
|
|
|
+ parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
|
|
|
+ ):
|
|
|
+ current_parallel = parent_parallel
|
|
|
+ else:
|
|
|
+ # fetch parent parallel's parent parallel
|
|
|
+ parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
|
|
|
+ if parent_parallel_parent_parallel_id:
|
|
|
+ parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
|
|
|
+ if parent_parallel_parent_parallel and (
|
|
|
+ not parent_parallel_parent_parallel.end_to_node_id
|
|
|
or (
|
|
|
- parent_parallel
|
|
|
- and parent_parallel.end_to_node_id
|
|
|
- and target_node_id == parent_parallel.end_to_node_id
|
|
|
+ parent_parallel_parent_parallel.end_to_node_id
|
|
|
+ and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
|
|
|
)
|
|
|
- or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
|
|
):
|
|
|
- outside_parallel_target_node_ids.add(target_node_id)
|
|
|
+ current_parallel = parent_parallel_parent_parallel
|
|
|
|
|
|
- if len(outside_parallel_target_node_ids) == 1:
|
|
|
- if (
|
|
|
- parent_parallel
|
|
|
- and parent_parallel.end_to_node_id
|
|
|
- and parallel.end_to_node_id == parent_parallel.end_to_node_id
|
|
|
- ):
|
|
|
- parallel.end_to_node_id = None
|
|
|
- else:
|
|
|
- parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
|
|
-
|
|
|
- for graph_edge in target_node_edges:
|
|
|
- current_parallel = None
|
|
|
- if parallel:
|
|
|
- current_parallel = parallel
|
|
|
- elif parent_parallel:
|
|
|
- if not parent_parallel.end_to_node_id or (
|
|
|
- parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
|
|
|
- ):
|
|
|
- current_parallel = parent_parallel
|
|
|
- else:
|
|
|
- # fetch parent parallel's parent parallel
|
|
|
- parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
|
|
|
- if parent_parallel_parent_parallel_id:
|
|
|
- parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
|
|
|
- if parent_parallel_parent_parallel and (
|
|
|
- not parent_parallel_parent_parallel.end_to_node_id
|
|
|
- or (
|
|
|
- parent_parallel_parent_parallel.end_to_node_id
|
|
|
- and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
|
|
|
- )
|
|
|
- ):
|
|
|
- current_parallel = parent_parallel_parent_parallel
|
|
|
-
|
|
|
- cls._recursively_add_parallels(
|
|
|
- edge_mapping=edge_mapping,
|
|
|
- reverse_edge_mapping=reverse_edge_mapping,
|
|
|
- start_node_id=graph_edge.target_node_id,
|
|
|
- parallel_mapping=parallel_mapping,
|
|
|
- node_parallel_mapping=node_parallel_mapping,
|
|
|
- parent_parallel=current_parallel,
|
|
|
- )
|
|
|
+ return current_parallel
|
|
|
|
|
|
@classmethod
|
|
|
def _check_exceed_parallel_limit(
|