
fix: issue #10596 by making the iteration node outputs right (#11394)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
yihong 4 months ago
parent commit d9d5d35a77
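
The core of this fix lives in iteration_node.py: iteration outputs are no longer run through jsonable_encoder, and a list-of-lists result is flattened before being emitted. A minimal sketch of that flattening step, using a hypothetical helper name and sample data (illustrative only, not the node's actual API):

    # Mirrors the flattening added in the iteration node; the function name is hypothetical.
    from typing import Any

    def flatten_iteration_outputs(outputs: list[Any]) -> list[Any]:
        # If every collected iteration result is itself a list, merge them into one flat list.
        if all(isinstance(output, list) for output in outputs):
            return [item for sublist in outputs for item in sublist]
        return outputs

    print(flatten_iteration_outputs([[1, 2], [3]]))  # [1, 2, 3]
    print(flatten_iteration_outputs(["a", "b"]))     # ['a', 'b'] (unchanged)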

+ 1 - 13
api/core/app/entities/queue_entities.py

@@ -2,7 +2,7 @@ from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional
 
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
     output: Optional[Any] = None  # output for the current iteration
     duration: Optional[float] = None
 
-    @field_validator("output", mode="before")
-    @classmethod
-    def set_output(cls, v):
-        """
-        Set output
-        """
-        if v is None:
-            return None
-        if isinstance(v, int | float | str | bool | dict | list):
-            return v
-        raise ValueError("output must be a valid type")
-
 
 class QueueIterationCompletedEvent(AppQueueEvent):
     """

+ 46 - 30
api/core/workflow/nodes/iteration/iteration_node.py

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
 from flask import Flask, current_app
 
 from configs import dify_config
-from core.model_runtime.utils.encoders import jsonable_encoder
+from core.variables import IntegerVariable
 from core.workflow.entities.node_entities import (
     NodeRunMetadataKey,
     NodeRunResult,
@@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]):
             iteration_node_data=self.node_data,
             index=0,
             pre_iteration_output=None,
+            duration=None,
         )
         iter_run_map: dict[str, float] = {}
         outputs: list[Any] = [None] * len(iterator_list_value)
         try:
             if self.node_data.is_parallel:
                 futures: list[Future] = []
-                q = Queue()
+                q: Queue = Queue()
                 thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
                 for index, item in enumerate(iterator_list_value):
                     future: Future = thread_pool.submit(
                         self._run_single_iter_parallel,
-                        current_app._get_current_object(),
+                        current_app._get_current_object(),  # type: ignore
                         q,
                         iterator_list_value,
                         inputs,
@@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                     future.add_done_callback(thread_pool.task_done_callback)
                     futures.append(future)
                 succeeded_count = 0
+                empty_count = 0
                 while True:
                     try:
                         event = q.get(timeout=1)
@@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]):
             else:
                 for _ in range(len(iterator_list_value)):
                     yield from self._run_single_iter(
-                        iterator_list_value,
-                        variable_pool,
-                        inputs,
-                        outputs,
-                        start_at,
-                        graph_engine,
-                        iteration_graph,
-                        iter_run_map,
+                        iterator_list_value=iterator_list_value,
+                        variable_pool=variable_pool,
+                        inputs=inputs,
+                        outputs=outputs,
+                        start_at=start_at,
+                        graph_engine=graph_engine,
+                        iteration_graph=iteration_graph,
+                        iter_run_map=iter_run_map,
                     )
             if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                 outputs = [output for output in outputs if output is not None]
+
+            # Flatten the list of lists
+            if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
+                outputs = [item for sublist in outputs for item in sublist]
+
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
@@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 start_at=start_at,
                 inputs=inputs,
-                outputs={"output": jsonable_encoder(outputs)},
+                outputs={"output": outputs},
                 steps=len(iterator_list_value),
                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
             )
@@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    outputs={"output": jsonable_encoder(outputs)},
+                    outputs={"output": outputs},
                     metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
                 )
             )
@@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 start_at=start_at,
                 inputs=inputs,
-                outputs={"output": jsonable_encoder(outputs)},
+                outputs={"output": outputs},
                 steps=len(iterator_list_value),
                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                 error=str(e),
@@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         :param node_data: node data
         :return:
         """
-        variable_mapping = {
+        variable_mapping: dict[str, Sequence[str]] = {
             f"{node_id}.input_selector": node_data.iterator_selector,
         }
 
@@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
                     graph_config=graph_config, config=sub_node_config
                 )
-                sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
+                sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
             except NotImplementedError:
                 sub_node_variable_mapping = {}
 
@@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]):
         return variable_mapping
 
     def _handle_event_metadata(
-        self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
-    ) -> NodeRunStartedEvent | BaseNodeEvent:
+        self,
+        *,
+        event: BaseNodeEvent | InNodeEvent,
+        iter_run_index: int,
+        parallel_mode_run_id: str | None,
+    ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
         """
         add iteration metadata to event.
         """
@@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
 
     def _run_single_iter(
         self,
+        *,
         iterator_list_value: list[str],
         variable_pool: VariablePool,
         inputs: dict[str, list],
@@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]):
         try:
             rst = graph_engine.run()
             # get current iteration index
-            current_index = variable_pool.get([self.node_id, "index"]).value
+            index_variable = variable_pool.get([self.node_id, "index"])
+            if not isinstance(index_variable, IntegerVariable):
+                raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
+            current_index = index_variable.value
             iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
             next_index = int(current_index) + 1
-
-            if current_index is None:
-                raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
             for event in rst:
                 if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                     event.in_iteration_id = self.node_id
@@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]):
                     continue
 
                 if isinstance(event, NodeRunSucceededEvent):
-                    yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                    yield self._handle_event_metadata(
+                        event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
+                    )
                 elif isinstance(event, BaseGraphEvent):
                     if isinstance(event, GraphRunFailedEvent):
                         # iteration run failed
@@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                                 parallel_mode_run_id=parallel_mode_run_id,
                                 start_at=start_at,
                                 inputs=inputs,
-                                outputs={"output": jsonable_encoder(outputs)},
+                                outputs={"output": outputs},
                                 steps=len(iterator_list_value),
                                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                                 error=event.error,
@@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                                 iteration_node_data=self.node_data,
                                 start_at=start_at,
                                 inputs=inputs,
-                                outputs={"output": jsonable_encoder(outputs)},
+                                outputs={"output": outputs},
                                 steps=len(iterator_list_value),
                                 metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
                                 error=event.error,
@@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]):
                             )
                         )
                         return
-                else:
-                    event = cast(InNodeEvent, event)
-                    metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                elif isinstance(event, InNodeEvent):
+                    # event = cast(InNodeEvent, event)
+                    metadata_event = self._handle_event_metadata(
+                        event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
+                    )
                     if isinstance(event, NodeRunFailedEvent):
                         if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
                             yield NodeInIterationFailedEvent(
@@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 iteration_node_data=self.node_data,
                 index=next_index,
                 parallel_mode_run_id=parallel_mode_run_id,
-                pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
+                pre_iteration_output=current_iteration_output or None,
                 duration=duration,
             )
 
@@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         index: int,
         item: Any,
         iter_run_map: dict[str, float],
-    ) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    ):
         """
         run single iteration in parallel mode
         """