|
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
|
|
|
from flask import Flask, current_app
|
|
|
|
|
|
from configs import dify_config
|
|
|
-from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
+from core.variables import IntegerVariable
|
|
|
from core.workflow.entities.node_entities import (
|
|
|
NodeRunMetadataKey,
|
|
|
NodeRunResult,
|
|
@@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
iteration_node_data=self.node_data,
|
|
|
index=0,
|
|
|
pre_iteration_output=None,
|
|
|
+ duration=None,
|
|
|
)
|
|
|
iter_run_map: dict[str, float] = {}
|
|
|
outputs: list[Any] = [None] * len(iterator_list_value)
|
|
|
try:
|
|
|
if self.node_data.is_parallel:
|
|
|
futures: list[Future] = []
|
|
|
- q = Queue()
|
|
|
+ q: Queue = Queue()
|
|
|
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
|
|
|
for index, item in enumerate(iterator_list_value):
|
|
|
future: Future = thread_pool.submit(
|
|
|
self._run_single_iter_parallel,
|
|
|
- current_app._get_current_object(),
|
|
|
+ current_app._get_current_object(), # type: ignore
|
|
|
q,
|
|
|
iterator_list_value,
|
|
|
inputs,
|
|
@@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
future.add_done_callback(thread_pool.task_done_callback)
|
|
|
futures.append(future)
|
|
|
succeeded_count = 0
|
|
|
+ empty_count = 0
|
|
|
while True:
|
|
|
try:
|
|
|
event = q.get(timeout=1)
|
|
@@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
else:
|
|
|
for _ in range(len(iterator_list_value)):
|
|
|
yield from self._run_single_iter(
|
|
|
- iterator_list_value,
|
|
|
- variable_pool,
|
|
|
- inputs,
|
|
|
- outputs,
|
|
|
- start_at,
|
|
|
- graph_engine,
|
|
|
- iteration_graph,
|
|
|
- iter_run_map,
|
|
|
+ iterator_list_value=iterator_list_value,
|
|
|
+ variable_pool=variable_pool,
|
|
|
+ inputs=inputs,
|
|
|
+ outputs=outputs,
|
|
|
+ start_at=start_at,
|
|
|
+ graph_engine=graph_engine,
|
|
|
+ iteration_graph=iteration_graph,
|
|
|
+ iter_run_map=iter_run_map,
|
|
|
)
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
|
|
outputs = [output for output in outputs if output is not None]
|
|
|
+
|
|
|
+ # Flatten the list of lists
|
|
|
+ if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
|
|
|
+ outputs = [item for sublist in outputs for item in sublist]
|
|
|
+
|
|
|
yield IterationRunSucceededEvent(
|
|
|
iteration_id=self.id,
|
|
|
iteration_node_id=self.node_id,
|
|
@@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
iteration_node_data=self.node_data,
|
|
|
start_at=start_at,
|
|
|
inputs=inputs,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ outputs={"output": outputs},
|
|
|
steps=len(iterator_list_value),
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
)
|
|
@@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
yield RunCompletedEvent(
|
|
|
run_result=NodeRunResult(
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ outputs={"output": outputs},
|
|
|
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
|
|
|
)
|
|
|
)
|
|
@@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
iteration_node_data=self.node_data,
|
|
|
start_at=start_at,
|
|
|
inputs=inputs,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ outputs={"output": outputs},
|
|
|
steps=len(iterator_list_value),
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
error=str(e),
|
|
@@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
:param node_data: node data
|
|
|
:return:
|
|
|
"""
|
|
|
- variable_mapping = {
|
|
|
+ variable_mapping: dict[str, Sequence[str]] = {
|
|
|
f"{node_id}.input_selector": node_data.iterator_selector,
|
|
|
}
|
|
|
|
|
@@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
|
|
graph_config=graph_config, config=sub_node_config
|
|
|
)
|
|
|
- sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
|
|
+ sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
|
|
except NotImplementedError:
|
|
|
sub_node_variable_mapping = {}
|
|
|
|
|
@@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
return variable_mapping
|
|
|
|
|
|
def _handle_event_metadata(
|
|
|
- self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
|
|
|
- ) -> NodeRunStartedEvent | BaseNodeEvent:
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ event: BaseNodeEvent | InNodeEvent,
|
|
|
+ iter_run_index: int,
|
|
|
+ parallel_mode_run_id: str | None,
|
|
|
+ ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
|
|
"""
|
|
|
add iteration metadata to event.
|
|
|
"""
|
|
@@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
|
|
|
def _run_single_iter(
|
|
|
self,
|
|
|
+ *,
|
|
|
iterator_list_value: list[str],
|
|
|
variable_pool: VariablePool,
|
|
|
inputs: dict[str, list],
|
|
@@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
try:
|
|
|
rst = graph_engine.run()
|
|
|
# get current iteration index
|
|
|
- current_index = variable_pool.get([self.node_id, "index"]).value
|
|
|
+ index_variable = variable_pool.get([self.node_id, "index"])
|
|
|
+ if not isinstance(index_variable, IntegerVariable):
|
|
|
+ raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
|
|
|
+ current_index = index_variable.value
|
|
|
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
|
|
|
next_index = int(current_index) + 1
|
|
|
-
|
|
|
- if current_index is None:
|
|
|
- raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
|
|
|
for event in rst:
|
|
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
|
|
event.in_iteration_id = self.node_id
|
|
@@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
continue
|
|
|
|
|
|
if isinstance(event, NodeRunSucceededEvent):
|
|
|
- yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
|
|
+ yield self._handle_event_metadata(
|
|
|
+ event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
|
|
+ )
|
|
|
elif isinstance(event, BaseGraphEvent):
|
|
|
if isinstance(event, GraphRunFailedEvent):
|
|
|
# iteration run failed
|
|
@@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
start_at=start_at,
|
|
|
inputs=inputs,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ outputs={"output": outputs},
|
|
|
steps=len(iterator_list_value),
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
error=event.error,
|
|
@@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
iteration_node_data=self.node_data,
|
|
|
start_at=start_at,
|
|
|
inputs=inputs,
|
|
|
- outputs={"output": jsonable_encoder(outputs)},
|
|
|
+ outputs={"output": outputs},
|
|
|
steps=len(iterator_list_value),
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
|
|
error=event.error,
|
|
@@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
)
|
|
|
)
|
|
|
return
|
|
|
- else:
|
|
|
- event = cast(InNodeEvent, event)
|
|
|
- metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
|
|
+ elif isinstance(event, InNodeEvent):
|
|
|
+ # event = cast(InNodeEvent, event)
|
|
|
+ metadata_event = self._handle_event_metadata(
|
|
|
+ event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
|
|
|
+ )
|
|
|
if isinstance(event, NodeRunFailedEvent):
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
|
|
yield NodeInIterationFailedEvent(
|
|
@@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
iteration_node_data=self.node_data,
|
|
|
index=next_index,
|
|
|
parallel_mode_run_id=parallel_mode_run_id,
|
|
|
- pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
|
|
|
+ pre_iteration_output=current_iteration_output or None,
|
|
|
duration=duration,
|
|
|
)
|
|
|
|
|
@@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|
|
index: int,
|
|
|
item: Any,
|
|
|
iter_run_map: dict[str, float],
|
|
|
- ) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
|
|
+ ):
|
|
|
"""
|
|
|
run single iteration in parallel mode
|
|
|
"""
|