瀏覽代碼

feat(llm_node): support order in text and files (#11837)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 月之前
父節點
當前提交
996a9135f6

+ 11 - 17
api/core/file/file_manager.py

@@ -1,15 +1,14 @@
 import base64
 
 from configs import dify_config
-from core.file import file_repository
 from core.helper import ssrf_proxy
 from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 
 from . import helpers
@@ -41,7 +40,7 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-):
+) -> MultiModalPromptMessageContent:
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:
@@ -70,16 +69,13 @@ def to_prompt_message_content(
 
 
 def download(f: File, /):
-    if f.transfer_method == FileTransferMethod.TOOL_FILE:
-        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-        return _download_file_content(tool_file.file_key)
-    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
-        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-        return _download_file_content(upload_file.key)
-    # remote file
-    response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
-    response.raise_for_status()
-    return response.content
+    if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
+        return _download_file_content(f._storage_key)
+    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
+        response.raise_for_status()
+        return response.content
+    raise ValueError(f"unsupported transfer method: {f.transfer_method}")
 
 
 def _download_file_content(path: str, /):
@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
             response.raise_for_status()
             data = response.content
         case FileTransferMethod.LOCAL_FILE:
-            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-            data = _download_file_content(upload_file.key)
+            data = _download_file_content(f._storage_key)
         case FileTransferMethod.TOOL_FILE:
-            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-            data = _download_file_content(tool_file.file_key)
+            data = _download_file_content(f._storage_key)
 
     encoded_string = base64.b64encode(data).decode("utf-8")
     return encoded_string

+ 0 - 32
api/core/file/file_repository.py

@@ -1,32 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from models import ToolFile, UploadFile
-
-from .models import File
-
-
-def get_upload_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(UploadFile).filter(
-        UploadFile.id == file.related_id,
-        UploadFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"upload file {file.related_id} not found")
-    return record
-
-
-def get_tool_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(ToolFile).filter(
-        ToolFile.id == file.related_id,
-        ToolFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"tool file {file.related_id} not found")
-    return record

+ 32 - 0
api/core/file/models.py

@@ -47,6 +47,38 @@ class File(BaseModel):
     mime_type: Optional[str] = None
     size: int = -1
 
+    # Those properties are private, should not be exposed to the outside.
+    _storage_key: str
+
+    def __init__(
+        self,
+        *,
+        id: Optional[str] = None,
+        tenant_id: str,
+        type: FileType,
+        transfer_method: FileTransferMethod,
+        remote_url: Optional[str] = None,
+        related_id: Optional[str] = None,
+        filename: Optional[str] = None,
+        extension: Optional[str] = None,
+        mime_type: Optional[str] = None,
+        size: int = -1,
+        storage_key: str,
+    ):
+        super().__init__(
+            id=id,
+            tenant_id=tenant_id,
+            type=type,
+            transfer_method=transfer_method,
+            remote_url=remote_url,
+            related_id=related_id,
+            filename=filename,
+            extension=extension,
+            mime_type=mime_type,
+            size=size,
+        )
+        self._storage_key = storage_key
+
     def to_dict(self) -> Mapping[str, str | int | None]:
         data = self.model_dump(mode="json")
         return {

+ 2 - 0
api/core/model_runtime/entities/__init__.py

@@ -4,6 +4,7 @@ from .message_entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     PromptMessageContentType,
@@ -27,6 +28,7 @@ __all__ = [
     "LLMResultChunkDelta",
     "LLMUsage",
     "ModelPropertyKey",
+    "MultiModalPromptMessageContent",
     "PromptMessage",
     "PromptMessage",
     "PromptMessageContent",

+ 4 - 4
api/core/model_runtime/entities/message_entities.py

@@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     """
 
     type: PromptMessageContentType
-    format: str = Field(..., description="the format of multi-modal file")
-    base64_data: str = Field("", description="the base64 data of multi-modal file")
-    url: str = Field("", description="the url of multi-modal file")
-    mime_type: str = Field(..., description="the mime type of multi-modal file")
+    format: str = Field(default=..., description="the format of multi-modal file")
+    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
+    url: str = Field(default="", description="the url of multi-modal file")
+    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
 
     @computed_field(return_type=str)
     @property

+ 1 - 0
api/core/workflow/nodes/llm/entities.py

@@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
 
 
 class LLMNodeChatModelMessage(ChatModelMessage):
+    text: str = ""
     jinja2_text: Optional[str] = None
 
 

+ 69 - 76
api/core/workflow/nodes/llm/node.py

@@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
                     query = query_variable.text
 
             prompt_messages, stop = self._fetch_prompt_messages(
-                user_query=query,
-                user_files=files,
+                sys_query=query,
+                sys_files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config,
@@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_prompt_messages(
         self,
         *,
-        user_query: str | None = None,
-        user_files: Sequence["File"],
+        sys_query: str | None = None,
+        sys_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         if isinstance(prompt_template, list):
             # For chat model
             prompt_messages.extend(
-                _handle_list_messages(
+                self._handle_list_messages(
                     messages=prompt_template,
                     context=context,
                     jinja2_variables=jinja2_variables,
@@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
             prompt_messages.extend(memory_messages)
 
             # Add current query to the prompt messages
-            if user_query:
+            if sys_query:
                 message = LLMNodeChatModelMessage(
-                    text=user_query,
+                    text=sys_query,
                     role=PromptMessageRole.USER,
                     edition_type="basic",
                 )
                 prompt_messages.extend(
-                    _handle_list_messages(
+                    self._handle_list_messages(
                         messages=[message],
                         context="",
                         jinja2_variables=[],
@@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
                 raise ValueError("Invalid prompt content type")
 
             # Add current query to the prompt message
-            if user_query:
+            if sys_query:
                 if prompt_content_type == str:
-                    prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                    prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
                     prompt_messages[0].content = prompt_content
                 elif prompt_content_type == list:
                     for content_item in prompt_content:
                         if content_item.type == PromptMessageContentType.TEXT:
-                            content_item.data = user_query + "\n" + content_item.data
+                            content_item.data = sys_query + "\n" + content_item.data
                 else:
                     raise ValueError("Invalid prompt content type")
         else:
             raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
 
-        if vision_enabled and user_files:
+        # The sys_files will be deprecated later
+        if vision_enabled and sys_files:
             file_prompts = []
-            for file in user_files:
+            for file in sys_files:
                 file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                 file_prompts.append(file_prompt)
+            # If last prompt is a user prompt, add files into its contents,
+            # otherwise append a new user prompt
             if (
                 len(prompt_messages) > 0
                 and isinstance(prompt_messages[-1], UserPromptMessage)
@@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             else:
                 prompt_messages.append(UserPromptMessage(content=file_prompts))
 
-        # Filter prompt messages
+        # Remove empty messages and filter unsupported content
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
@@ -846,6 +849,58 @@ class LLMNode(BaseNode[LLMNodeData]):
             },
         }
 
+    def _handle_list_messages(
+        self,
+        *,
+        messages: Sequence[LLMNodeChatModelMessage],
+        context: Optional[str],
+        jinja2_variables: Sequence[VariableSelector],
+        variable_pool: VariablePool,
+        vision_detail_config: ImagePromptMessageContent.DETAIL,
+    ) -> Sequence[PromptMessage]:
+        prompt_messages: list[PromptMessage] = []
+        for message in messages:
+            contents: list[PromptMessageContent] = []
+            if message.edition_type == "jinja2":
+                result_text = _render_jinja2_message(
+                    template=message.jinja2_text or "",
+                    jinjia2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                )
+                contents.append(TextPromptMessageContent(data=result_text))
+            else:
+                # Get segment group from basic message
+                if context:
+                    template = message.text.replace("{#context#}", context)
+                else:
+                    template = message.text
+                segment_group = variable_pool.convert_template(template)
+
+                # Process segments for images
+                for segment in segment_group.value:
+                    if isinstance(segment, ArrayFileSegment):
+                        for file in segment.value:
+                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                                file_content = file_manager.to_prompt_message_content(
+                                    file, image_detail_config=vision_detail_config
+                                )
+                                contents.append(file_content)
+                    elif isinstance(segment, FileSegment):
+                        file = segment.value
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=vision_detail_config
+                            )
+                            contents.append(file_content)
+                    else:
+                        plain_text = segment.markdown.strip()
+                        if plain_text:
+                            contents.append(TextPromptMessageContent(data=plain_text))
+            prompt_message = _combine_message_content_with_role(contents=contents, role=message.role)
+            prompt_messages.append(prompt_message)
+
+        return prompt_messages
+
 
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
@@ -880,68 +935,6 @@ def _render_jinja2_message(
     return result_text
 
 
-def _handle_list_messages(
-    *,
-    messages: Sequence[LLMNodeChatModelMessage],
-    context: Optional[str],
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-    vision_detail_config: ImagePromptMessageContent.DETAIL,
-) -> Sequence[PromptMessage]:
-    prompt_messages = []
-    for message in messages:
-        if message.edition_type == "jinja2":
-            result_text = _render_jinja2_message(
-                template=message.jinja2_text or "",
-                jinjia2_variables=jinja2_variables,
-                variable_pool=variable_pool,
-            )
-            prompt_message = _combine_message_content_with_role(
-                contents=[TextPromptMessageContent(data=result_text)], role=message.role
-            )
-            prompt_messages.append(prompt_message)
-        else:
-            # Get segment group from basic message
-            if context:
-                template = message.text.replace("{#context#}", context)
-            else:
-                template = message.text
-            segment_group = variable_pool.convert_template(template)
-
-            # Process segments for images
-            file_contents = []
-            for segment in segment_group.value:
-                if isinstance(segment, ArrayFileSegment):
-                    for file in segment.value:
-                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                            file_content = file_manager.to_prompt_message_content(
-                                file, image_detail_config=vision_detail_config
-                            )
-                            file_contents.append(file_content)
-                if isinstance(segment, FileSegment):
-                    file = segment.value
-                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                        file_content = file_manager.to_prompt_message_content(
-                            file, image_detail_config=vision_detail_config
-                        )
-                        file_contents.append(file_content)
-
-            # Create message with text from all segments
-            plain_text = segment_group.text
-            if plain_text:
-                prompt_message = _combine_message_content_with_role(
-                    contents=[TextPromptMessageContent(data=plain_text)], role=message.role
-                )
-                prompt_messages.append(prompt_message)
-
-            if file_contents:
-                # Create message with image contents
-                prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
-                prompt_messages.append(prompt_message)
-
-    return prompt_messages
-
-
 def _calculate_rest_token(
     *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
 ) -> int:

+ 2 - 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -86,10 +86,10 @@ class QuestionClassifierNode(LLMNode):
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
-            user_query=query,
+            sys_query=query,
             memory=memory,
             model_config=model_config,
-            user_files=files,
+            sys_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
             variable_pool=variable_pool,

+ 3 - 0
api/factories/file_factory.py

@@ -139,6 +139,7 @@ def _build_from_local_file(
         remote_url=row.source_url,
         related_id=mapping.get("upload_file_id"),
         size=row.size,
+        storage_key=row.key,
     )
 
 
@@ -168,6 +169,7 @@ def _build_from_remote_url(
         mime_type=mime_type,
         extension=extension,
         size=file_size,
+        storage_key="",
     )
 
 
@@ -220,6 +222,7 @@ def _build_from_tool_file(
         extension=extension,
         mime_type=tool_file.mimetype,
         size=tool_file.size,
+        storage_key=tool_file.file_key,
     )
 
 

+ 33 - 4
api/models/model.py

@@ -560,13 +560,29 @@ class Conversation(db.Model):
     @property
     def inputs(self):
         inputs = self._inputs.copy()
+
+        # Convert file mapping to File object
         for key, value in inputs.items():
+            # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
+            from factories import file_factory
+
             if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                inputs[key] = File.model_validate(value)
+                if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
+                    value["tool_file_id"] = value["related_id"]
+                elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+                    value["upload_file_id"] = value["related_id"]
+                inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
             elif isinstance(value, list) and all(
                 isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
             ):
-                inputs[key] = [File.model_validate(item) for item in value]
+                inputs[key] = []
+                for item in value:
+                    if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
+                        item["tool_file_id"] = item["related_id"]
+                    elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+                        item["upload_file_id"] = item["related_id"]
+                    inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
+
         return inputs
 
     @inputs.setter
@@ -758,12 +774,25 @@ class Message(db.Model):
     def inputs(self):
         inputs = self._inputs.copy()
         for key, value in inputs.items():
+            # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
+            from factories import file_factory
+
             if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                inputs[key] = File.model_validate(value)
+                if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
+                    value["tool_file_id"] = value["related_id"]
+                elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+                    value["upload_file_id"] = value["related_id"]
+                inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
             elif isinstance(value, list) and all(
                 isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
             ):
-                inputs[key] = [File.model_validate(item) for item in value]
+                inputs[key] = []
+                for item in value:
+                    if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
+                        item["tool_file_id"] = item["related_id"]
+                    elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+                        item["upload_file_id"] = item["related_id"]
+                    inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
         return inputs
 
     @inputs.setter

+ 1 - 0
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py

@@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
             type=FileType.IMAGE,
             transfer_method=FileTransferMethod.REMOTE_URL,
             remote_url="https://example.com/image1.jpg",
+            storage_key="",
         )
     ]
 

+ 3 - 27
api/tests/unit_tests/core/test_file.py

@@ -1,34 +1,9 @@
 import json
 
-from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
+from core.file import File, FileTransferMethod, FileType, FileUploadConfig
 from models.workflow import Workflow
 
 
-def test_file_loads_and_dumps():
-    file = File(
-        id="file1",
-        tenant_id="tenant1",
-        type=FileType.IMAGE,
-        transfer_method=FileTransferMethod.REMOTE_URL,
-        remote_url="https://example.com/image1.jpg",
-    )
-
-    file_dict = file.model_dump()
-    assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
-    assert file_dict["type"] == file.type.value
-    assert isinstance(file_dict["type"], str)
-    assert file_dict["transfer_method"] == file.transfer_method.value
-    assert isinstance(file_dict["transfer_method"], str)
-    assert "_extra_config" not in file_dict
-
-    file_obj = File.model_validate(file_dict)
-    assert file_obj.id == file.id
-    assert file_obj.tenant_id == file.tenant_id
-    assert file_obj.type == file.type
-    assert file_obj.transfer_method == file.transfer_method
-    assert file_obj.remote_url == file.remote_url
-
-
 def test_file_to_dict():
     file = File(
         id="file1",
@@ -36,10 +11,11 @@ def test_file_to_dict():
         type=FileType.IMAGE,
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url="https://example.com/image1.jpg",
+        storage_key="storage_key",
     )
 
     file_dict = file.to_dict()
-    assert "_extra_config" not in file_dict
+    assert "_storage_key" not in file_dict
     assert "url" in file_dict
 
 

+ 2 - 0
api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py

@@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
                 type=FileType.IMAGE,
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 related_id="1111",
+                storage_key="",
             ),
         ),
     )
@@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
                 type=FileType.IMAGE,
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 related_id="1111",
+                storage_key="",
             ),
         ),
     )

+ 44 - 11
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py

@@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
+from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
+from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
@@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
         filename="test.jpg",
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="1",
+        storage_key="",
     )
     llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
 
@@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
             filename="test1.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="1",
+            storage_key="",
         ),
         File(
             id="2",
@@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
             filename="test2.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="2",
+            storage_key="",
         ),
     ]
     llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
@@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
             filename="test1.jpg",
             transfer_method=FileTransferMethod.REMOTE_URL,
             remote_url=fake_remote_url,
+            storage_key="",
         )
     ]
 
     fake_query = faker.sentence()
 
     prompt_messages, _ = llm_node._fetch_prompt_messages(
-        user_query=fake_query,
-        user_files=files,
+        sys_query=fake_query,
+        sys_files=files,
         context=None,
         memory=None,
         model_config=model_config,
@@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
     test_scenarios = [
         LLMNodeTestScenario(
             description="No files",
-            user_query=fake_query,
-            user_files=[],
+            sys_query=fake_query,
+            sys_files=[],
             features=[],
             vision_enabled=False,
             vision_detail=None,
@@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
         ),
         LLMNodeTestScenario(
             description="User files",
-            user_query=fake_query,
-            user_files=[
+            sys_query=fake_query,
+            sys_files=[
                 File(
                     tenant_id="test",
                     type=FileType.IMAGE,
@@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     remote_url=fake_remote_url,
                     extension=".jpg",
                     mime_type="image/jpg",
+                    storage_key="",
                 )
             ],
             vision_enabled=True,
@@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
         ),
         LLMNodeTestScenario(
             description="Prompt template with variable selector of File",
-            user_query=fake_query,
-            user_files=[],
+            sys_query=fake_query,
+            sys_files=[],
             vision_enabled=False,
             vision_detail=fake_vision_detail,
             features=[ModelFeature.VISION],
@@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     remote_url=fake_remote_url,
                     extension=".jpg",
                     mime_type="image/jpg",
+                    storage_key="",
                 )
             },
         ),
@@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
 
         # Call the method under test
         prompt_messages, _ = llm_node._fetch_prompt_messages(
-            user_query=scenario.user_query,
-            user_files=scenario.user_files,
+            sys_query=scenario.sys_query,
+            sys_files=scenario.sys_files,
             context=fake_context,
             memory=memory,
             model_config=model_config,
@@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
         assert (
             prompt_messages == scenario.expected_messages
         ), f"Message content mismatch in scenario: {scenario.description}"
+
+
+def test_handle_list_messages_basic(llm_node):
+    messages = [
+        LLMNodeChatModelMessage(
+            text="Hello, {#context#}",
+            role=PromptMessageRole.USER,
+            edition_type="basic",
+        )
+    ]
+    context = "world"
+    jinja2_variables = []
+    variable_pool = llm_node.graph_runtime_state.variable_pool
+    vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
+
+    result = llm_node._handle_list_messages(
+        messages=messages,
+        context=context,
+        jinja2_variables=jinja2_variables,
+        variable_pool=variable_pool,
+        vision_detail_config=vision_detail_config,
+    )
+
+    assert len(result) == 1
+    assert isinstance(result[0], UserPromptMessage)
+    assert result[0].content == [TextPromptMessageContent(data="Hello, world")]

+ 2 - 2
api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py

@@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
     """Test scenario for LLM node testing."""
 
     description: str = Field(..., description="Description of the test scenario")
-    user_query: str = Field(..., description="User query input")
-    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
+    sys_query: str = Field(..., description="User query input")
+    sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
     vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
     vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
     features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")

+ 1 - 0
api/tests/unit_tests/core/workflow/nodes/test_if_else.py

@@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 related_id="1",
                 filename="ab",
+                storage_key="",
             ),
         ],
     )

+ 6 - 0
api/tests/unit_tests/core/workflow/nodes/test_list_operator.py

@@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
             tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related1",
+            storage_key="",
         ),
         File(
             filename="document1.pdf",
@@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
             tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related2",
+            storage_key="",
         ),
         File(
             filename="image2.png",
@@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
             tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related3",
+            storage_key="",
         ),
         File(
             filename="audio1.mp3",
@@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
             tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related4",
+            storage_key="",
         ),
     ]
     variable = ArrayFileSegment(value=files)
@@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
         mime_type="text/plain",
         remote_url="https://example.com/test_file.txt",
         related_id="test_related_id",
+        storage_key="",
     )
 
     # Test each case
@@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
         mime_type=None,
         remote_url=None,
         related_id="test_related_id",
+        storage_key="",
     )
 
     assert _get_file_extract_string_func(key="name")(empty_file) == ""

+ 1 - 0
api/tests/unit_tests/core/workflow/test_variable_pool.py

@@ -19,6 +19,7 @@ def file():
         related_id="test_related_id",
         remote_url="test_url",
         filename="test_file.txt",
+        storage_key="",
     )