Forráskód Böngészése

refactor(api/core/workflow/nodes/variable_assigner): Split into multi files. (#7434)

-LAN- 8 hónapja
szülő
commit
4f64a5d36d

+ 8 - 109
api/core/workflow/nodes/variable_assigner/__init__.py

@@ -1,109 +1,8 @@
-from collections.abc import Sequence
-from enum import Enum
-from typing import Optional, cast
-
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.app.segments import SegmentType, Variable, factory
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseNode
-from extensions.ext_database import db
-from models import ConversationVariable, WorkflowNodeExecutionStatus
-
-
-class VariableAssignerNodeError(Exception):
-    pass
-
-
-class WriteMode(str, Enum):
-    OVER_WRITE = 'over-write'
-    APPEND = 'append'
-    CLEAR = 'clear'
-
-
-class VariableAssignerData(BaseNodeData):
-    title: str = 'Variable Assigner'
-    desc: Optional[str] = 'Assign a value to a variable'
-    assigned_variable_selector: Sequence[str]
-    write_mode: WriteMode
-    input_variable_selector: Sequence[str]
-
-
-class VariableAssignerNode(BaseNode):
-    _node_data_cls: type[BaseNodeData] = VariableAssignerData
-    _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
-
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        data = cast(VariableAssignerData, self.node_data)
-
-        # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
-        original_variable = variable_pool.get(data.assigned_variable_selector)
-        if not isinstance(original_variable, Variable):
-            raise VariableAssignerNodeError('assigned variable not found')
-
-        match data.write_mode:
-            case WriteMode.OVER_WRITE:
-                income_value = variable_pool.get(data.input_variable_selector)
-                if not income_value:
-                    raise VariableAssignerNodeError('input value not found')
-                updated_variable = original_variable.model_copy(update={'value': income_value.value})
-
-            case WriteMode.APPEND:
-                income_value = variable_pool.get(data.input_variable_selector)
-                if not income_value:
-                    raise VariableAssignerNodeError('input value not found')
-                updated_value = original_variable.value + [income_value.value]
-                updated_variable = original_variable.model_copy(update={'value': updated_value})
-
-            case WriteMode.CLEAR:
-                income_value = get_zero_value(original_variable.value_type)
-                updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
-
-            case _:
-                raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
-
-        # Over write the variable.
-        variable_pool.add(data.assigned_variable_selector, updated_variable)
-
-        # Update conversation variable.
-        # TODO: Find a better way to use the database.
-        conversation_id = variable_pool.get(['sys', 'conversation_id'])
-        if not conversation_id:
-            raise VariableAssignerNodeError('conversation_id not found')
-        update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
-
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs={
-                'value': income_value.to_object(),
-            },
-        )
-
-
-def update_conversation_variable(conversation_id: str, variable: Variable):
-    stmt = select(ConversationVariable).where(
-        ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
-    )
-    with Session(db.engine) as session:
-        row = session.scalar(stmt)
-        if not row:
-            raise VariableAssignerNodeError('conversation variable not found in the database')
-        row.data = variable.model_dump_json()
-        session.commit()
-
-
-def get_zero_value(t: SegmentType):
-    match t:
-        case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
-            return factory.build_segment([])
-        case SegmentType.OBJECT:
-            return factory.build_segment({})
-        case SegmentType.STRING:
-            return factory.build_segment('')
-        case SegmentType.NUMBER:
-            return factory.build_segment(0)
-        case _:
-            raise VariableAssignerNodeError(f'unsupported variable type: {t}')
+from .node import VariableAssignerNode
+from .node_data import VariableAssignerData, WriteMode
+
+__all__ = [
+    'VariableAssignerNode',
+    'VariableAssignerData',
+    'WriteMode',
+]

+ 2 - 0
api/core/workflow/nodes/variable_assigner/exc.py

@@ -0,0 +1,2 @@
+class VariableAssignerNodeError(Exception):
+    pass

+ 92 - 0
api/core/workflow/nodes/variable_assigner/node.py

@@ -0,0 +1,92 @@
+from typing import cast
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.app.segments import SegmentType, Variable, factory
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import BaseNode
+from extensions.ext_database import db
+from models import ConversationVariable, WorkflowNodeExecutionStatus
+
+from .exc import VariableAssignerNodeError
+from .node_data import VariableAssignerData, WriteMode
+
+
+class VariableAssignerNode(BaseNode):
+    _node_data_cls: type[BaseNodeData] = VariableAssignerData
+    _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
+
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+        data = cast(VariableAssignerData, self.node_data)
+
+        # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
+        original_variable = variable_pool.get(data.assigned_variable_selector)
+        if not isinstance(original_variable, Variable):
+            raise VariableAssignerNodeError('assigned variable not found')
+
+        match data.write_mode:
+            case WriteMode.OVER_WRITE:
+                income_value = variable_pool.get(data.input_variable_selector)
+                if not income_value:
+                    raise VariableAssignerNodeError('input value not found')
+                updated_variable = original_variable.model_copy(update={'value': income_value.value})
+
+            case WriteMode.APPEND:
+                income_value = variable_pool.get(data.input_variable_selector)
+                if not income_value:
+                    raise VariableAssignerNodeError('input value not found')
+                updated_value = original_variable.value + [income_value.value]
+                updated_variable = original_variable.model_copy(update={'value': updated_value})
+
+            case WriteMode.CLEAR:
+                income_value = get_zero_value(original_variable.value_type)
+                updated_variable = original_variable.model_copy(update={'value': income_value.to_object()})
+
+            case _:
+                raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
+
+        # Over write the variable.
+        variable_pool.add(data.assigned_variable_selector, updated_variable)
+
+        # TODO: Move database operation to the pipeline.
+        # Update conversation variable.
+        conversation_id = variable_pool.get(['sys', 'conversation_id'])
+        if not conversation_id:
+            raise VariableAssignerNodeError('conversation_id not found')
+        update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
+
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs={
+                'value': income_value.to_object(),
+            },
+        )
+
+
+def update_conversation_variable(conversation_id: str, variable: Variable):
+    stmt = select(ConversationVariable).where(
+        ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+    )
+    with Session(db.engine) as session:
+        row = session.scalar(stmt)
+        if not row:
+            raise VariableAssignerNodeError('conversation variable not found in the database')
+        row.data = variable.model_dump_json()
+        session.commit()
+
+
+def get_zero_value(t: SegmentType):
+    match t:
+        case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
+            return factory.build_segment([])
+        case SegmentType.OBJECT:
+            return factory.build_segment({})
+        case SegmentType.STRING:
+            return factory.build_segment('')
+        case SegmentType.NUMBER:
+            return factory.build_segment(0)
+        case _:
+            raise VariableAssignerNodeError(f'unsupported variable type: {t}')

+ 19 - 0
api/core/workflow/nodes/variable_assigner/node_data.py

@@ -0,0 +1,19 @@
+from collections.abc import Sequence
+from enum import Enum
+from typing import Optional
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+
+
+class WriteMode(str, Enum):
+    OVER_WRITE = 'over-write'
+    APPEND = 'append'
+    CLEAR = 'clear'
+
+
+class VariableAssignerData(BaseNodeData):
+    title: str = 'Variable Assigner'
+    desc: Optional[str] = 'Assign a value to a variable'
+    assigned_variable_selector: Sequence[str]
+    write_mode: WriteMode
+    input_variable_selector: Sequence[str]

+ 2 - 2
api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py

@@ -52,7 +52,7 @@ def test_overwrite_string_variable():
         input_variable,
     )
 
-    with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
+    with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
         node.run(variable_pool)
         mock_run.assert_called_once()
 
@@ -103,7 +103,7 @@ def test_append_variable_to_array():
         input_variable,
     )
 
-    with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
+    with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
         node.run(variable_pool)
         mock_run.assert_called_once()