Forráskód Böngészése

fix(workflow): refine variable type checks in LLMNode (#10051)

-LAN- 5 hónapja
szülő
commit
3b53e06e0d

+ 3 - 5
api/core/workflow/nodes/llm/node.py

@@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]):
         variable = self.graph_runtime_state.variable_pool.get(selector)
         if variable is None:
             return []
-        if isinstance(variable, FileSegment):
+        elif isinstance(variable, FileSegment):
             return [variable.value]
-        if isinstance(variable, ArrayFileSegment):
+        elif isinstance(variable, ArrayFileSegment):
             return variable.value
-        # FIXME: Temporary fix for empty array,
-        # all variables added to variable pool should be a Segment instance.
-        if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
+        elif isinstance(variable, NoneSegment | ArrayAnySegment):
             return []
         raise ValueError(f"Invalid variable type: {type(variable)}")
 

+ 125 - 0
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py

@@ -0,0 +1,125 @@
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file import File, FileTransferMethod, FileType
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+from core.workflow.nodes.answer import AnswerStreamGenerateRoute
+from core.workflow.nodes.end import EndStreamParam
+from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.node import LLMNode
+from models.enums import UserFrom
+from models.workflow import WorkflowType
+
+
+class TestLLMNode:
+    @pytest.fixture
+    def llm_node(self):
+        data = LLMNodeData(
+            title="Test LLM",
+            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+            prompt_template=[],
+            memory=None,
+            context=ContextConfig(enabled=False),
+            vision=VisionConfig(
+                enabled=True,
+                configs=VisionConfigOptions(
+                    variable_selector=["sys", "files"],
+                    detail=ImagePromptMessageContent.DETAIL.HIGH,
+                ),
+            ),
+        )
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+        )
+        node = LLMNode(
+            id="1",
+            config={
+                "id": "1",
+                "data": data.model_dump(),
+            },
+            graph_init_params=GraphInitParams(
+                tenant_id="1",
+                app_id="1",
+                workflow_type=WorkflowType.WORKFLOW,
+                workflow_id="1",
+                graph_config={},
+                user_id="1",
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.SERVICE_API,
+                call_depth=0,
+            ),
+            graph=Graph(
+                root_node_id="1",
+                answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                    answer_dependencies={},
+                    answer_generate_route={},
+                ),
+                end_stream_param=EndStreamParam(
+                    end_dependencies={},
+                    end_stream_variable_selector_mapping={},
+                ),
+            ),
+            graph_runtime_state=GraphRuntimeState(
+                variable_pool=variable_pool,
+                start_at=0,
+            ),
+        )
+        return node
+
+    def test_fetch_files_with_file_segment(self, llm_node):
+        file = File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="1",
+        )
+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+        result = llm_node._fetch_files(selector=["sys", "files"])
+        assert result == [file]
+
+    def test_fetch_files_with_array_file_segment(self, llm_node):
+        files = [
+            File(
+                id="1",
+                tenant_id="test",
+                type=FileType.IMAGE,
+                filename="test1.jpg",
+                transfer_method=FileTransferMethod.LOCAL_FILE,
+                related_id="1",
+            ),
+            File(
+                id="2",
+                tenant_id="test",
+                type=FileType.IMAGE,
+                filename="test2.jpg",
+                transfer_method=FileTransferMethod.LOCAL_FILE,
+                related_id="2",
+            ),
+        ]
+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+        result = llm_node._fetch_files(selector=["sys", "files"])
+        assert result == files
+
+    def test_fetch_files_with_none_segment(self, llm_node):
+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+        result = llm_node._fetch_files(selector=["sys", "files"])
+        assert result == []
+
+    def test_fetch_files_with_array_any_segment(self, llm_node):
+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+        result = llm_node._fetch_files(selector=["sys", "files"])
+        assert result == []
+
+    def test_fetch_files_with_non_existent_variable(self, llm_node):
+        result = llm_node._fetch_files(selector=["sys", "files"])
+        assert result == []