Browse Source

fix: Introduce ArrayVariable and update iteration node to handle it (#12001)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 months ago
parent
commit
9cfd1c67b6

+ 2 - 0
api/core/variables/__init__.py

@@ -21,6 +21,7 @@ from .variables import (
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
+    ArrayVariable,
     FileVariable,
     FloatVariable,
     IntegerVariable,
@@ -43,6 +44,7 @@ __all__ = [
     "ArraySegment",
     "ArrayStringSegment",
     "ArrayStringVariable",
+    "ArrayVariable",
     "FileSegment",
     "FileVariable",
     "FloatSegment",

+ 9 - 4
api/core/variables/variables.py

@@ -10,6 +10,7 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArraySegment,
     ArrayStringSegment,
     FileSegment,
     FloatSegment,
@@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
     pass
 
 
-class ArrayAnyVariable(ArrayAnySegment, Variable):
+class ArrayVariable(ArraySegment, Variable):
     pass
 
 
-class ArrayStringVariable(ArrayStringSegment, Variable):
+class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
     pass
 
 
-class ArrayNumberVariable(ArrayNumberSegment, Variable):
+class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
     pass
 
 
-class ArrayObjectVariable(ArrayObjectSegment, Variable):
+class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
+    pass
+
+
+class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
     pass
 
 

+ 9 - 6
api/core/workflow/nodes/iteration/iteration_node.py

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
 from flask import Flask, current_app
 
 from configs import dify_config
-from core.variables import IntegerVariable
+from core.variables import ArrayVariable, IntegerVariable, NoneVariable
 from core.workflow.entities.node_entities import (
     NodeRunMetadataKey,
     NodeRunResult,
@@ -75,12 +75,15 @@ class IterationNode(BaseNode[IterationNodeData]):
         """
         Run the node.
         """
-        iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
+        variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
 
-        if not iterator_list_segment:
-            raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
+        if not variable:
+            raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
 
-        if len(iterator_list_segment.value) == 0:
+        if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
+            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
+
+        if isinstance(variable, NoneVariable) or len(variable.value) == 0:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -89,7 +92,7 @@ class IterationNode(BaseNode[IterationNodeData]):
             )
             return
 
-        iterator_list_value = iterator_list_segment.to_object()
+        iterator_list_value = variable.to_object()
 
         if not isinstance(iterator_list_value, list):
             raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")