Browse Source

feat(variable-handling): enhance variable and segment conversion (#10483)

-LAN- 5 months ago
parent
commit
70c2ec8ed5

+ 2 - 0
api/core/variables/__init__.py

@@ -17,6 +17,7 @@ from .segments import (
 from .types import SegmentType
 from .variables import (
     ArrayAnyVariable,
+    ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
@@ -58,4 +59,5 @@ __all__ = [
     "ArrayStringSegment",
     "FileSegment",
     "FileVariable",
+    "ArrayFileVariable",
 ]

+ 11 - 2
api/core/variables/variables.py

@@ -1,9 +1,13 @@
+from collections.abc import Sequence
+from uuid import uuid4
+
 from pydantic import Field
 
 from core.helper import encrypter
 
 from .segments import (
     ArrayAnySegment,
+    ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
     ArrayStringSegment,
@@ -24,11 +28,12 @@ class Variable(Segment):
     """
 
     id: str = Field(
-        default="",
-        description="Unique identity for variable. It's only used by environment variables now.",
+        default=lambda _: str(uuid4()),
+        description="Unique identity for variable.",
     )
     name: str
     description: str = Field(default="", description="Description of the variable.")
+    selector: Sequence[str] = Field(default_factory=list)
 
 
 class StringVariable(StringSegment, Variable):
@@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable):
 
 class FileVariable(FileSegment, Variable):
     pass
+
+
+class ArrayFileVariable(ArrayFileSegment, Variable):
+    pass

+ 6 - 3
api/core/workflow/entities/variable_pool.py

@@ -95,13 +95,16 @@ class VariablePool(BaseModel):
         if len(selector) < 2:
             raise ValueError("Invalid selector")
 
+        if isinstance(value, Variable):
+            variable = value
         if isinstance(value, Segment):
-            v = value
+            variable = variable_factory.segment_to_variable(segment=value, selector=selector)
         else:
-            v = variable_factory.build_segment(value)
+            segment = variable_factory.build_segment(value)
+            variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
 
         hash_key = hash(tuple(selector[1:]))
-        self.variable_dictionary[selector[0]][hash_key] = v
+        self.variable_dictionary[selector[0]][hash_key] = variable
 
     def get(self, selector: Sequence[str], /) -> Segment | None:
         """

+ 69 - 11
api/factories/variable_factory.py

@@ -1,34 +1,65 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any
+from uuid import uuid4
 
 from configs import dify_config
 from core.file import File
-from core.variables import (
+from core.variables.exc import VariableError
+from core.variables.segments import (
     ArrayAnySegment,
     ArrayFileSegment,
     ArrayNumberSegment,
-    ArrayNumberVariable,
     ArrayObjectSegment,
-    ArrayObjectVariable,
     ArraySegment,
     ArrayStringSegment,
-    ArrayStringVariable,
     FileSegment,
     FloatSegment,
-    FloatVariable,
     IntegerSegment,
-    IntegerVariable,
     NoneSegment,
     ObjectSegment,
-    ObjectVariable,
-    SecretVariable,
     Segment,
-    SegmentType,
     StringSegment,
+)
+from core.variables.types import SegmentType
+from core.variables.variables import (
+    ArrayAnyVariable,
+    ArrayFileVariable,
+    ArrayNumberVariable,
+    ArrayObjectVariable,
+    ArrayStringVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    NoneVariable,
+    ObjectVariable,
+    SecretVariable,
     StringVariable,
     Variable,
 )
-from core.variables.exc import VariableError
+
+
+class InvalidSelectorError(ValueError):
+    pass
+
+
+class UnsupportedSegmentTypeError(Exception):
+    pass
+
+
+# Define the constant
+SEGMENT_TO_VARIABLE_MAP = {
+    StringSegment: StringVariable,
+    IntegerSegment: IntegerVariable,
+    FloatSegment: FloatVariable,
+    ObjectSegment: ObjectVariable,
+    FileSegment: FileVariable,
+    ArrayStringSegment: ArrayStringVariable,
+    ArrayNumberSegment: ArrayNumberVariable,
+    ArrayObjectSegment: ArrayObjectVariable,
+    ArrayFileSegment: ArrayFileVariable,
+    ArrayAnySegment: ArrayAnyVariable,
+    NoneSegment: NoneVariable,
+}
 
 
 def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
@@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment:
             case _:
                 raise ValueError(f"not supported value {value}")
     raise ValueError(f"not supported value {value}")
+
+
+def segment_to_variable(
+    *,
+    segment: Segment,
+    selector: Sequence[str],
+    id: str | None = None,
+    name: str | None = None,
+    description: str = "",
+) -> Variable:
+    if isinstance(segment, Variable):
+        return segment
+    name = name or selector[-1]
+    id = id or str(uuid4())
+
+    segment_type = type(segment)
+    if segment_type not in SEGMENT_TO_VARIABLE_MAP:
+        raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
+
+    variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
+    return variable_class(
+        id=id,
+        name=name,
+        description=description,
+        value=segment.value,
+        selector=selector,
+    )

+ 3 - 2
api/tests/unit_tests/core/app/segments/test_segment.py

@@ -1,5 +1,5 @@
 from core.helper import encrypter
-from core.variables import SecretVariable, StringSegment
+from core.variables import SecretVariable, StringVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 
@@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group():
     segments_group = variable_pool.convert_template(template)
     assert segments_group.text == "fake-user-id"
     assert segments_group.log == "fake-user-id"
-    assert segments_group.value == [StringSegment(value="fake-user-id")]
+    assert isinstance(segments_group.value[0], StringVariable)
+    assert segments_group.value[0].value == "fake-user-id"