Просмотр исходного кода

refactor(api/core/app/segments): Support more kinds of Segments. (#6706)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 9 месяцев назад
Родитель
Сommit
c6996a48a4

+ 16 - 1
api/core/app/segments/__init__.py

@@ -1,5 +1,14 @@
 from .segment_group import SegmentGroup
-from .segments import NoneSegment, Segment
+from .segments import (
+    ArraySegment,
+    FileSegment,
+    FloatSegment,
+    IntegerSegment,
+    NoneSegment,
+    ObjectSegment,
+    Segment,
+    StringSegment,
+)
 from .types import SegmentType
 from .variables import (
     ArrayVariable,
@@ -27,4 +36,10 @@ __all__ = [
     'Segment',
     'NoneSegment',
     'NoneVariable',
+    'IntegerSegment',
+    'FloatSegment',
+    'ObjectSegment',
+    'ArraySegment',
+    'FileSegment',
+    'StringSegment',
 ]

+ 20 - 21
api/core/app/segments/factory.py

@@ -3,15 +3,20 @@ from typing import Any
 
 from core.file.file_obj import FileVar
 
-from .segments import Segment, StringSegment
+from .segments import (
+    ArraySegment,
+    FileSegment,
+    FloatSegment,
+    IntegerSegment,
+    NoneSegment,
+    ObjectSegment,
+    Segment,
+    StringSegment,
+)
 from .types import SegmentType
 from .variables import (
-    ArrayVariable,
-    FileVariable,
     FloatVariable,
     IntegerVariable,
-    NoneVariable,
-    ObjectVariable,
     SecretVariable,
     StringVariable,
     Variable,
@@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
     raise ValueError(f'not supported value type {value_type}')
 
 
-def build_anonymous_variable(value: Any, /) -> Variable:
+def build_segment(value: Any, /) -> Segment:
     if value is None:
-        return NoneVariable(name='anonymous')
+        return NoneSegment()
     if isinstance(value, str):
-        return StringVariable(name='anonymous', value=value)
+        return StringSegment(value=value)
     if isinstance(value, int):
-        return IntegerVariable(name='anonymous', value=value)
+        return IntegerSegment(value=value)
     if isinstance(value, float):
-        return FloatVariable(name='anonymous', value=value)
+        return FloatSegment(value=value)
     if isinstance(value, dict):
         # TODO: Limit the depth of the object
-        obj = {k: build_anonymous_variable(v) for k, v in value.items()}
-        return ObjectVariable(name='anonymous', value=obj)
+        obj = {k: build_segment(v) for k, v in value.items()}
+        return ObjectSegment(value=obj)
     if isinstance(value, list):
         # TODO: Limit the depth of the array
-        elements = [build_anonymous_variable(v) for v in value]
-        return ArrayVariable(name='anonymous', value=elements)
+        elements = [build_segment(v) for v in value]
+        return ArraySegment(value=elements)
     if isinstance(value, FileVar):
-        return FileVariable(name='anonymous', value=value)
-    raise ValueError(f'not supported value {value}')
-
-
-def build_segment(value: Any, /) -> Segment:
-    if isinstance(value, str):
-        return StringSegment(value=value)
+        return FileSegment(value=value)
     raise ValueError(f'not supported value {value}')

+ 3 - 2
api/core/app/segments/parser.py

@@ -1,8 +1,9 @@
 import re
 
-from core.app.segments import SegmentGroup, factory
 from core.workflow.entities.variable_pool import VariablePool
 
+from . import SegmentGroup, factory
+
 VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
 
 
@@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool):
             segments.append(value)
         else:
             segments.append(factory.build_segment(part))
-    return SegmentGroup(segments=segments)
+    return SegmentGroup(value=segments)

+ 10 - 7
api/core/app/segments/segment_group.py

@@ -1,19 +1,22 @@
-from pydantic import BaseModel
-
 from .segments import Segment
+from .types import SegmentType
 
 
-class SegmentGroup(BaseModel):
-    segments: list[Segment]
+class SegmentGroup(Segment):
+    value_type: SegmentType = SegmentType.GROUP
+    value: list[Segment]
 
     @property
     def text(self):
-        return ''.join([segment.text for segment in self.segments])
+        return ''.join([segment.text for segment in self.value])
 
     @property
     def log(self):
-        return ''.join([segment.log for segment in self.segments])
+        return ''.join([segment.log for segment in self.value])
 
     @property
     def markdown(self):
-        return ''.join([segment.markdown for segment in self.segments])
+        return ''.join([segment.markdown for segment in self.value])
+
+    def to_object(self):
+        return [segment.to_object() for segment in self.value]

+ 58 - 0
api/core/app/segments/segments.py

@@ -1,7 +1,11 @@
+import json
+from collections.abc import Mapping, Sequence
 from typing import Any
 
 from pydantic import BaseModel, ConfigDict, field_validator
 
+from core.file.file_obj import FileVar
+
 from .types import SegmentType
 
 
@@ -57,3 +61,57 @@ class NoneSegment(Segment):
 class StringSegment(Segment):
     value_type: SegmentType = SegmentType.STRING
     value: str
+
+class FloatSegment(Segment):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: float
+
+
+class IntegerSegment(Segment):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: int
+
+
+class ObjectSegment(Segment):
+    value_type: SegmentType = SegmentType.OBJECT
+    value: Mapping[str, Segment]
+
+    @property
+    def text(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
+
+    @property
+    def log(self) -> str:
+        # TODO: Process variables.
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+    @property
+    def markdown(self) -> str:
+        # TODO: Use markdown code block
+        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+
+    def to_object(self):
+        return {k: v.to_object() for k, v in self.value.items()}
+
+
+class ArraySegment(Segment):
+    value_type: SegmentType = SegmentType.ARRAY
+    value: Sequence[Segment]
+
+    @property
+    def markdown(self) -> str:
+        return '\n'.join(['- ' + item.markdown for item in self.value])
+
+    def to_object(self):
+        return [v.to_object() for v in self.value]
+
+
+class FileSegment(Segment):
+    value_type: SegmentType = SegmentType.FILE
+    # TODO: embed FileVar in this model.
+    value: FileVar
+
+    @property
+    def markdown(self) -> str:
+        return self.value.to_markdown()

+ 2 - 0
api/core/app/segments/types.py

@@ -9,3 +9,5 @@ class SegmentType(str, Enum):
     ARRAY = 'array'
     OBJECT = 'object'
     FILE = 'file'
+
+    GROUP = 'group'

+ 20 - 49
api/core/app/segments/variables.py

@@ -1,12 +1,18 @@
-import json
-from collections.abc import Mapping, Sequence
 
 from pydantic import Field
 
-from core.file.file_obj import FileVar
 from core.helper import encrypter
 
-from .segments import NoneSegment, Segment, StringSegment
+from .segments import (
+    ArraySegment,
+    FileSegment,
+    FloatSegment,
+    IntegerSegment,
+    NoneSegment,
+    ObjectSegment,
+    Segment,
+    StringSegment,
+)
 from .types import SegmentType
 
 
@@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable):
     pass
 
 
-class FloatVariable(Variable):
-    value_type: SegmentType = SegmentType.NUMBER
-    value: float
-
-
-class IntegerVariable(Variable):
-    value_type: SegmentType = SegmentType.NUMBER
-    value: int
-
-
-class ObjectVariable(Variable):
-    value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Variable]
-
-    @property
-    def text(self) -> str:
-        # TODO: Process variables.
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
-
-    @property
-    def log(self) -> str:
-        # TODO: Process variables.
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
-
-    @property
-    def markdown(self) -> str:
-        # TODO: Use markdown code block
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+class FloatVariable(FloatSegment, Variable):
+    pass
 
-    def to_object(self):
-        return {k: v.to_object() for k, v in self.value.items()}
 
+class IntegerVariable(IntegerSegment, Variable):
+    pass
 
-class ArrayVariable(Variable):
-    value_type: SegmentType = SegmentType.ARRAY
-    value: Sequence[Variable]
 
-    @property
-    def markdown(self) -> str:
-        return '\n'.join(['- ' + item.markdown for item in self.value])
+class ObjectVariable(ObjectSegment, Variable):
+    pass
 
-    def to_object(self):
-        return [v.to_object() for v in self.value]
 
+class ArrayVariable(ArraySegment, Variable):
+    pass
 
-class FileVariable(Variable):
-    value_type: SegmentType = SegmentType.FILE
-    # TODO: embed FileVar in this model.
-    value: FileVar
 
-    @property
-    def markdown(self) -> str:
-        return self.value.to_markdown()
+class FileVariable(FileSegment, Variable):
+    pass
 
 
 class SecretVariable(StringVariable):

+ 6 - 6
api/core/workflow/entities/variable_pool.py

@@ -4,7 +4,7 @@ from typing import Any, Union
 
 from typing_extensions import deprecated
 
-from core.app.segments import Variable, factory
+from core.app.segments import Segment, Variable, factory
 from core.file.file_obj import FileVar
 from core.workflow.entities.node_entities import SystemVariable
 
@@ -33,7 +33,7 @@ class VariablePool:
         # The first element of the selector is the node id, it's the first-level key in the dictionary.
         # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
         # elements of the selector except the first one.
-        self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
+        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
 
         # TODO: This user inputs is not used for pool.
         self.user_inputs = user_inputs
@@ -67,15 +67,15 @@ class VariablePool:
         if value is None:
             return
 
-        if not isinstance(value, Variable):
-            v = factory.build_anonymous_variable(value)
-        else:
+        if isinstance(value, Segment):
             v = value
+        else:
+            v = factory.build_segment(value)
 
         hash_key = hash(tuple(selector[1:]))
         self._variable_dictionary[selector[0]][hash_key] = v
 
-    def get(self, selector: Sequence[str], /) -> Variable | None:
+    def get(self, selector: Sequence[str], /) -> Segment | None:
         """
         Retrieves the value from the variable pool based on the given selector.
 

+ 1 - 0
api/core/workflow/nodes/tool/tool_node.py

@@ -126,6 +126,7 @@ class ToolNode(BaseNode):
             else:
                 tool_input = node_data.tool_parameters[parameter_name]
                 if tool_input.type == 'variable':
+                    # TODO: check if the variable exists in the variable pool
                     parameter_value = variable_pool.get(tool_input.value).value
                 else:
                     segment_group = parser.convert_template(

+ 0 - 0
api/tests/unit_tests/app/test_segment.py → api/tests/unit_tests/core/app/test_segment.py


+ 5 - 4
api/tests/unit_tests/app/test_variables.py → api/tests/unit_tests/core/app/test_variables.py

@@ -5,7 +5,8 @@ from core.app.segments import (
     ArrayVariable,
     FloatVariable,
     IntegerVariable,
-    NoneVariable,
+    NoneSegment,
+    ObjectSegment,
     ObjectVariable,
     SecretVariable,
     SegmentType,
@@ -139,10 +140,10 @@ def test_variable_to_object():
 
 
 def test_build_a_object_variable_with_none_value():
-    var = factory.build_anonymous_variable(
+    var = factory.build_segment(
         {
             'key1': None,
         }
     )
-    assert isinstance(var, ObjectVariable)
-    assert isinstance(var.value['key1'], NoneVariable)
+    assert isinstance(var, ObjectSegment)
+    assert isinstance(var.value['key1'], NoneSegment)