
feat: Allow using file variables directly in the LLM node and support more file types. (#10679)

Co-authored-by: Joel <iamjoel007@gmail.com>
-LAN- 5 months ago
Parent
Commit c5f7d650b5
36 changed files with 1036 additions and 268 deletions
  1. + 0 - 1     api/configs/app_config.py
  2. + 19 - 23   api/core/app/app_config/easy_ui_based_app/model_config/converter.py
  3. + 4 - 1     api/core/app/task_pipeline/workflow_cycle_manage.py
  4. + 25 - 44   api/core/file/file_manager.py
  5. + 2 - 1     api/core/memory/token_buffer_memory.py
  6. + 2 - 2     api/core/model_manager.py
  7. + 5 - 4     api/core/model_runtime/callbacks/base_callback.py
  8. + 2 - 0     api/core/model_runtime/entities/__init__.py
  9. + 11 - 2    api/core/model_runtime/entities/message_entities.py
  10. + 3 - 0    api/core/model_runtime/entities/model_entities.py
  11. + 10 - 10  api/core/model_runtime/model_providers/__base/large_language_model.py
  12. + 1 - 0    api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml
  13. + 1 - 0    api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml
  14. + 30 - 6   api/core/model_runtime/model_providers/anthropic/llm/llm.py
  15. + 1 - 0    api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml
  16. + 2 - 1    api/core/prompt/utils/prompt_message_util.py
  17. + 10 - 2   api/core/variables/segments.py
  18. + 16 - 2   api/core/workflow/nodes/llm/entities.py
  19. + 8 - 0    api/core/workflow/nodes/llm/exc.py
  20. + 357 - 55 api/core/workflow/nodes/llm/node.py
  21. + 4 - 2    api/core/workflow/nodes/question_classifier/question_classifier_node.py
  22. + 16 - 1   api/poetry.lock
  23. + 1 - 0    api/pyproject.toml
  24. + 0 - 1    api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py
  25. + 3 - 11   api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py
  26. + 455 - 96 api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
  27. + 25 - 0   api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
  28. + 1 - 0    web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx
  29. + 3 - 0    web/app/components/workflow/nodes/_base/components/variable/var-list.tsx
  30. + 3 - 0    web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx
  31. + 3 - 1    web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx
  32. + 1 - 0    web/app/components/workflow/nodes/code/panel.tsx
  33. + 1 - 0    web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx
  34. + 4 - 1    web/app/components/workflow/nodes/llm/panel.tsx
  35. + 6 - 1    web/app/components/workflow/nodes/llm/use-config.ts
  36. + 1 - 0    web/app/components/workflow/nodes/template-transform/panel.tsx

+ 0 - 1
api/configs/app_config.py

@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
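
Note: dropping frozen=True makes the settings instance mutable again. A minimal standalone sketch of the pydantic-settings behavior this changes (class names here are illustrative, not DifyConfig itself):

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class FrozenConfig(BaseSettings):
        model_config = SettingsConfigDict(frozen=True)
        DEBUG: bool = False

    class MutableConfig(BaseSettings):
        model_config = SettingsConfigDict()  # frozen not set, instances stay mutable
        DEBUG: bool = False

    frozen = FrozenConfig()
    try:
        frozen.DEBUG = True       # frozen=True: assignment raises a ValidationError
    except Exception as exc:
        print(type(exc).__name__)

    mutable = MutableConfig()
    mutable.DEBUG = True          # allowed once frozen=True is removed, e.g. for test patching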

+ 19 - 23
api/core/app/app_config/easy_ui_based_app/model_config/converter.py

@@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@ class ModelConfigConverter:
         )
 
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
-
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
-
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
-
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ class ModelConfigConverter:
 
         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
 
         return ModelConfigWithCredentialsEntity(
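
Note: with the skip_check flag gone, every call to convert() performs the credential and model-status checks. A hypothetical caller sketch, assuming an EasyUIBasedAppConfig instance is already in hand:

    from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
    from core.errors.error import (
        ModelCurrentlyNotSupportError,
        ProviderTokenNotInitError,
        QuotaExceededError,
    )

    def build_model_config(app_config):  # app_config: EasyUIBasedAppConfig
        try:
            return ModelConfigConverter.convert(app_config)
        except (ProviderTokenNotInitError, ModelCurrentlyNotSupportError, QuotaExceededError) as e:
            # Surface configuration problems to the caller instead of silently
            # bypassing them as the old skip_check=True path did.
            raise ValueError(f"Model is not usable: {e}") from e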

+ 4 - 1
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -217,9 +217,12 @@ class WorkflowCycleManage:
             ).total_seconds()
             db.session.commit()
 
-        db.session.refresh(workflow_run)
         db.session.close()
 
+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
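
Note: the refresh now happens in a short-lived Session with expire_on_commit=False, so the workflow_run attributes stay loaded after the session ends. A minimal sketch of the SQLAlchemy pattern, independent of the Dify models (engine setup is an assumption):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")  # assumption: any configured engine

    def refresh_detached(obj):
        with Session(engine, expire_on_commit=False) as session:
            session.add(obj)      # re-attach the detached instance
            session.refresh(obj)  # reload current column values from the database
        return obj                # attributes remain accessible after the session closes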

+ 25 - 44
api/core/file/file_manager.py

@@ -3,7 +3,12 @@ import base64
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 
@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")
 
 
 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
             if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
@@ -65,7 +52,7 @@ def to_prompt_message_content(
 
             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@@ -74,9 +61,20 @@ def to_prompt_message_content(
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")
 
 
 def download(f: File, /):
@@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
         case FileTransferMethod.REMOTE_URL:
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
-        case _:
-            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+
+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string
 
 
 def _to_base64_data_string(f: File, /):
@@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
     return f"data:{f.mime_type};base64,{encoded_string}"
 
 
-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
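
Note: with the DOCUMENT branch added, a PDF file no longer falls through to the "not supported" error. A hedged usage sketch, assuming an already-built File object (its constructor kwargs vary and are not shown here):

    from core.file import FileType, file_manager
    from core.model_runtime.entities import DocumentPromptMessageContent

    def file_to_llm_content(f):  # f: core.file.models.File
        content = file_manager.to_prompt_message_content(f)
        if f.type == FileType.DOCUMENT:
            assert isinstance(content, DocumentPromptMessageContent)
            assert content.encode_format == "base64"   # documents are always base64-encoded
            assert content.mime_type == f.mime_type    # e.g. "application/pdf"
        return content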

+ 2 - 1
api/core/memory/token_buffer_memory.py

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Optional
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -27,7 +28,7 @@ class TokenBufferMemory:
 
     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit

+ 2 - 2
api/core/model_manager.py

@@ -100,10 +100,10 @@ class ModelInstance:
 
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
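
Note: the list[...] to Sequence[...] changes here and in the callback/base model signatures widen what callers may pass. A minimal sketch of why, independent of the Dify classes:

    from collections.abc import Sequence

    def join_stop_words(stop: Sequence[str] | None) -> str:
        # Sequence is read-only and covariant: the callee promises not to mutate it.
        return ", ".join(stop or [])

    join_stop_words(["###", "\n\n"])   # a list is accepted
    join_stop_words(("###", "\n\n"))   # so is a tuple, which list[str] would reject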

+ 5 - 4
api/core/model_runtime/callbacks/base_callback.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:

+ 2 - 0
api/core/model_runtime/entities/__init__.py

@@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@ __all__ = [
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]

+ 11 - 2
api/core/model_runtime/entities/message_entities.py

@@ -1,6 +1,7 @@
 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"
 
 
 class PromptMessageContent(BaseModel):
@@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
     detail: DETAIL = DETAIL.LOW
 
 
+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """
 
     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None
 
     def is_empty(self) -> bool:
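
Note: a sketch of building a user message that mixes text with the new document content type (the PDF bytes are placeholder data):

    import base64

    from core.model_runtime.entities import DocumentPromptMessageContent, TextPromptMessageContent
    from core.model_runtime.entities.message_entities import UserPromptMessage

    pdf_bytes = b"%PDF-1.4 ..."  # assumption: some PDF payload
    message = UserPromptMessage(
        content=[
            TextPromptMessageContent(data="Summarize the attached document."),
            DocumentPromptMessageContent(
                encode_format="base64",
                mime_type="application/pdf",
                data=base64.b64encode(pdf_bytes).decode("utf-8"),
            ),
        ]
    )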

+ 3 - 0
api/core/model_runtime/entities/model_entities.py

@@ -87,6 +87,9 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
 
 
 class DefaultParameterName(str, Enum):
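
Note: the new DOCUMENT, VIDEO and AUDIO members let callers gate multimodal content on a model's declared capabilities. A small sketch (model_schema is assumed to be an AIModelEntity whose features list comes from the provider YAML):

    from core.model_runtime.entities.model_entities import ModelFeature

    def supported_content_features(model_schema) -> set[ModelFeature]:
        wanted = {ModelFeature.VISION, ModelFeature.DOCUMENT, ModelFeature.VIDEO, ModelFeature.AUDIO}
        return wanted & set(model_schema.features or [])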

+ 10 - 10
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -2,7 +2,7 @@ import logging
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import Optional, Union
 
 from pydantic import ConfigDict
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -212,7 +212,7 @@ if you are not sure about the structure.
             )
 
         model_parameters.pop("response_format")
-        stop = stop or []
+        stop = list(stop) if stop is not None else []
         stop.extend(["\n```", "```\n"])
         block_prompts = block_prompts.replace("{{block}}", code_block)
 
@@ -408,7 +408,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -479,7 +479,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -601,7 +601,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -647,7 +647,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -694,7 +694,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -742,7 +742,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml

@@ -7,6 +7,7 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml

@@ -7,6 +7,7 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000

+ 30 - 6
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -1,7 +1,7 @@
 import base64
 import io
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 
 import anthropic
@@ -21,9 +21,9 @@ from httpx import Timeout
 from PIL import Image
 
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         self,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         # Add the new header for claude-3-5-sonnet-20240620 model
         extra_headers = {}
         if model == "claude-3-5-sonnet-20240620":
-            if model_parameters.get("max_tokens") > 4096:
+            if model_parameters.get("max_tokens", 0) > 4096:
                 extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
 
+        if any(
+            isinstance(content, DocumentPromptMessageContent)
+            for prompt_message in prompt_messages
+            if isinstance(prompt_message.content, list)
+            for content in prompt_message.content
+        ):
+            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
+
         if tools:
             extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
             response = client.beta.tools.messages.create(
@@ -504,6 +513,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                     "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                                 }
                                 sub_messages.append(sub_message_dict)
+                            elif isinstance(message_content, DocumentPromptMessageContent):
+                                if message_content.mime_type != "application/pdf":
+                                    raise ValueError(
+                                        f"Unsupported document type {message_content.mime_type}, "
+                                        "only support application/pdf"
+                                    )
+                                sub_message_dict = {
+                                    "type": "document",
+                                    "source": {
+                                        "type": message_content.encode_format,
+                                        "media_type": message_content.mime_type,
+                                        "data": message_content.data,
+                                    },
+                                }
+                                sub_messages.append(sub_message_dict)
                         prompt_message_dicts.append({"role": "user", "content": sub_messages})
                 elif isinstance(message, AssistantPromptMessage):
                     message = cast(AssistantPromptMessage, message)
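
Note: the shape of the request fragment produced by the document branch above, with illustrative values; when any DocumentPromptMessageContent is present, the beta header "anthropic-beta: pdfs-2024-09-25" is attached to the request:

    document_block = {
        "type": "document",
        "source": {
            "type": "base64",                  # DocumentPromptMessageContent.encode_format
            "media_type": "application/pdf",   # only application/pdf is accepted here
            "data": "<base64-encoded PDF bytes>",
        },
    }
    user_message = {"role": "user", "content": [document_block]}
    extra_headers = {"anthropic-beta": "pdfs-2024-09-25"}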

+ 1 - 0
api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml

@@ -7,6 +7,7 @@ features:
   - multi-tool-call
   - agent-thought
   - stream-tool-call
+  - audio
 model_properties:
   mode: chat
   context_size: 128000

+ 2 - 1
api/core/prompt/utils/prompt_message_util.py

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import cast
 
 from core.model_runtime.entities import (
@@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
 
 class PromptMessageUtil:
     @staticmethod
-    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
+    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
         """
         Prompt messages to prompt for saving.
         :param model_mode: model mode

+ 10 - 2
api/core/variables/segments.py

@@ -118,11 +118,11 @@ class FileSegment(Segment):
 
     @property
     def log(self) -> str:
-        return str(self.value)
+        return ""
 
     @property
     def text(self) -> str:
-        return str(self.value)
+        return ""
 
 
 class ArrayAnySegment(ArraySegment):
@@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
         for item in self.value:
             items.append(item.markdown)
         return "\n".join(items)
+
+    @property
+    def log(self) -> str:
+        return ""
+
+    @property
+    def text(self) -> str:
+        return ""

+ 16 - 2
api/core/workflow/nodes/llm/entities.py

@@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
 
 
 class PromptConfig(BaseModel):
-    jinja2_variables: Optional[list[VariableSelector]] = None
+    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
+
+    @field_validator("jinja2_variables", mode="before")
+    @classmethod
+    def convert_none_jinja2_variables(cls, v: Any):
+        if v is None:
+            return []
+        return v
 
 
 class LLMNodeChatModelMessage(ChatModelMessage):
@@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
     prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    prompt_config: Optional[PromptConfig] = None
+    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+
+    @field_validator("prompt_config", mode="before")
+    @classmethod
+    def convert_none_prompt_config(cls, v: Any):
+        if v is None:
+            return PromptConfig()
+        return v
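
Note: both validators follow the same "None means default" pattern, so older workflow definitions that serialized null still validate against the now non-optional fields. A minimal standalone sketch (the class name is a hypothetical stand-in):

    from typing import Any

    from pydantic import BaseModel, Field, field_validator

    class PromptConfigSketch(BaseModel):
        jinja2_variables: list[str] = Field(default_factory=list)

        @field_validator("jinja2_variables", mode="before")
        @classmethod
        def none_to_empty_list(cls, v: Any):
            # Runs before field validation, so an explicit null becomes the default.
            return [] if v is None else v

    PromptConfigSketch(jinja2_variables=None).jinja2_variables  # -> []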

+ 8 - 0
api/core/workflow/nodes/llm/exc.py

@@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
 
 class NoPromptFoundError(LLMNodeError):
     """Raised when no prompt is found in the LLM configuration."""
+
+
+class NotSupportedPromptTypeError(LLMNodeError):
+    """Raised when the prompt type is not supported."""
+
+
+class MemoryRolePrefixRequiredError(LLMNodeError):
+    """Raised when memory role prefix is required for completion model."""

+ 357 - 55
api/core/workflow/nodes/llm/node.py

@@ -1,4 +1,5 @@
 import json
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast
 
@@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.file import FileType, file_manager
+from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
-    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
-    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
@@ -32,8 +38,9 @@ from core.variables import (
     ObjectSegment,
     StringSegment,
 )
-from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode
@@ -62,14 +69,18 @@ from .exc import (
     InvalidVariableTypeError,
     LLMModeRequiredError,
     LLMNodeError,
+    MemoryRolePrefixRequiredError,
     ModelNotExistError,
     NoPromptFoundError,
+    NotSupportedPromptTypeError,
     VariableNotFoundError,
 )
 
 if TYPE_CHECKING:
     from core.file.models import File
 
+logger = logging.getLogger(__name__)
+
 
 class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
@@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
 
             # fetch prompt messages
             if self.node_data.memory:
-                query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
-                if not query:
-                    raise VariableNotFoundError("Query not found")
-                query = query.text
+                query = self.node_data.memory.query_prompt_template
             else:
                 query = None
 
             prompt_messages, stop = self._fetch_prompt_messages(
-                system_query=query,
-                inputs=inputs,
-                files=files,
+                user_query=query,
+                user_files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config,
@@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
                 memory_config=self.node_data.memory,
                 vision_enabled=self.node_data.vision.enabled,
                 vision_detail=self.node_data.vision.configs.detail,
+                variable_pool=self.graph_runtime_state.variable_pool,
+                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             )
 
             process_data = {
@@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
                 )
             )
             return
+        except Exception as e:
+            logger.exception(f"Node {self.node_id} failed to run")
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data,
+                )
+            )
+            return
 
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
 
@@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
         self,
         node_data_model: ModelConfig,
         model_instance: ModelInstance,
-        prompt_messages: list[PromptMessage],
-        stop: Optional[list[str]] = None,
+        prompt_messages: Sequence[PromptMessage],
+        stop: Optional[Sequence[str]] = None,
     ) -> Generator[NodeEvent, None, None]:
         db.session.close()
 
@@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_prompt_messages(
         self,
         *,
-        system_query: str | None = None,
-        inputs: dict[str, str] | None = None,
-        files: Sequence["File"],
+        user_query: str | None = None,
+        user_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
-    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        inputs = inputs or {}
-
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs=inputs,
-            query=system_query or "",
-            files=files,
-            context=context,
-            memory_config=memory_config,
-            memory=memory,
-            model_config=model_config,
-        )
-        stop = model_config.stop
+        variable_pool: VariablePool,
+        jinja2_variables: Sequence[VariableSelector],
+    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
+        prompt_messages = []
+
+        if isinstance(prompt_template, list):
+            # For chat model
+            prompt_messages.extend(
+                _handle_list_messages(
+                    messages=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                    vision_detail_config=vision_detail,
+                )
+            )
+
+            # Get memory messages for chat mode
+            memory_messages = _handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Extend prompt_messages with memory messages
+            prompt_messages.extend(memory_messages)
+
+            # Add current query to the prompt messages
+            if user_query:
+                message = LLMNodeChatModelMessage(
+                    text=user_query,
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                )
+                prompt_messages.extend(
+                    _handle_list_messages(
+                        messages=[message],
+                        context="",
+                        jinja2_variables=[],
+                        variable_pool=variable_pool,
+                        vision_detail_config=vision_detail,
+                    )
+                )
+
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(
+                _handle_completion_template(
+                    template=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                )
+            )
+
+            # Get memory text for completion model
+            memory_text = _handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Insert histories into the prompt
+            prompt_content = prompt_messages[0].content
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+
+            # Add current query to the prompt message
+            if user_query:
+                prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                prompt_messages[0].content = prompt_content
+        else:
+            errmsg = f"Prompt type {type(prompt_template)} is not supported"
+            logger.warning(errmsg)
+            raise NotSupportedPromptTypeError(errmsg)
+
+        if vision_enabled and user_files:
+            file_prompts = []
+            for file in user_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+        # Filter prompt messages
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
-            if prompt_message.is_empty():
-                continue
-
-            if not isinstance(prompt_message.content, str):
+            if isinstance(prompt_message.content, list):
                 prompt_message_content = []
-                for content_item in prompt_message.content or []:
-                    # Skip image if vision is disabled
-                    if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                for content_item in prompt_message.content:
+                    # Skip content if features are not defined
+                    if not model_config.model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        prompt_message_content.append(content_item)
                         continue
 
-                    if isinstance(content_item, ImagePromptMessageContent):
-                        # Override vision config if LLM node has vision config,
-                        # cuz vision detail is related to the configuration from FileUpload feature.
-                        content_item.detail = vision_detail
-                        prompt_message_content.append(content_item)
-                    elif isinstance(
-                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
+                    # Skip content if corresponding feature is not supported
+                    if (
+                        (
+                            content_item.type == PromptMessageContentType.IMAGE
+                            and ModelFeature.VISION not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.DOCUMENT
+                            and ModelFeature.DOCUMENT not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.VIDEO
+                            and ModelFeature.VIDEO not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.AUDIO
+                            and ModelFeature.AUDIO not in model_config.model_schema.features
+                        )
                     ):
-                        prompt_message_content.append(content_item)
-
-                if len(prompt_message_content) > 1:
-                    prompt_message.content = prompt_message_content
-                elif (
-                    len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
-                ):
+                        continue
+                    prompt_message_content.append(content_item)
+                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                     prompt_message.content = prompt_message_content[0].data
-
+                else:
+                    prompt_message.content = prompt_message_content
+            if prompt_message.is_empty():
+                continue
             filtered_prompt_messages.append(prompt_message)
 
-        if not filtered_prompt_messages:
+        if len(filtered_prompt_messages) == 0:
             raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
 
+        stop = model_config.stop
         return filtered_prompt_messages, stop
 
     @classmethod
@@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
                 }
             },
         }
+
+
+def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+    raise NotImplementedError(f"Role {role} is not supported")
+
+
+def _render_jinja2_message(
+    *,
+    template: str,
+    jinjia2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+):
+    if not template:
+        return ""
+
+    jinjia2_inputs = {}
+    for jinja2_variable in jinjia2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    code_execute_resp = CodeExecutor.execute_workflow_code_template(
+        language=CodeLanguage.JINJA2,
+        code=template,
+        inputs=jinjia2_inputs,
+    )
+    result_text = code_execute_resp["result"]
+    return result_text
+
+
+def _handle_list_messages(
+    *,
+    messages: Sequence[LLMNodeChatModelMessage],
+    context: Optional[str],
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    vision_detail_config: ImagePromptMessageContent.DETAIL,
+) -> Sequence[PromptMessage]:
+    prompt_messages = []
+    for message in messages:
+        if message.edition_type == "jinja2":
+            result_text = _render_jinja2_message(
+                template=message.jinja2_text or "",
+                jinjia2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+            )
+            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+            prompt_messages.append(prompt_message)
+        else:
+            # Get segment group from basic message
+            if context:
+                template = message.text.replace("{#context#}", context)
+            else:
+                template = message.text
+            segment_group = variable_pool.convert_template(template)
+
+            # Process segments for images
+            file_contents = []
+            for segment in segment_group.value:
+                if isinstance(segment, ArrayFileSegment):
+                    for file in segment.value:
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=vision_detail_config
+                            )
+                            file_contents.append(file_content)
+                if isinstance(segment, FileSegment):
+                    file = segment.value
+                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                        file_content = file_manager.to_prompt_message_content(
+                            file, image_detail_config=vision_detail_config
+                        )
+                        file_contents.append(file_content)
+
+            # Create message with text from all segments
+            plain_text = segment_group.text
+            if plain_text:
+                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+                prompt_messages.append(prompt_message)
+
+            if file_contents:
+                # Create message with image contents
+                prompt_message = UserPromptMessage(content=file_contents)
+                prompt_messages.append(prompt_message)
+
+    return prompt_messages
+
+
+def _calculate_rest_token(
+    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+) -> int:
+    rest_tokens = 2000
+
+    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    if model_context_tokens:
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+        )
+
+        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(str(parameter_rule.use_template))
+                    or 0
+                )
+
+        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+        rest_tokens = max(rest_tokens, 0)
+
+    return rest_tokens
+
+
+def _handle_memory_chat_mode(
+    *,
+    memory: TokenBufferMemory | None,
+    memory_config: MemoryConfig | None,
+    model_config: ModelConfigWithCredentialsEntity,
+) -> Sequence[PromptMessage]:
+    memory_messages = []
+    # Get messages from memory for chat model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        memory_messages = memory.get_history_prompt_messages(
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+        )
+    return memory_messages
+
+
+def _handle_memory_completion_mode(
+    *,
+    memory: TokenBufferMemory | None,
+    memory_config: MemoryConfig | None,
+    model_config: ModelConfigWithCredentialsEntity,
+) -> str:
+    memory_text = ""
+    # Get history text from memory for completion model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        if not memory_config.role_prefix:
+            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+        memory_text = memory.get_history_prompt_text(
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            human_prefix=memory_config.role_prefix.user,
+            ai_prefix=memory_config.role_prefix.assistant,
+        )
+    return memory_text
+
+
+def _handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: Optional[str],
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+) -> Sequence[PromptMessage]:
+    """Handle completion template processing outside of LLMNode class.
+
+    Args:
+        template: The completion model prompt template
+        context: Optional context string
+        jinja2_variables: Variables for jinja2 template rendering
+        variable_pool: Variable pool for template conversion
+
+    Returns:
+        Sequence of prompt messages
+    """
+    prompt_messages = []
+    if template.edition_type == "jinja2":
+        result_text = _render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinjia2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+        )
+    else:
+        if context:
+            template_text = template.text.replace("{#context#}", context)
+        else:
+            template_text = template.text
+        result_text = variable_pool.convert_template(template_text).text
+    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+    prompt_messages.append(prompt_message)
+    return prompt_messages
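
Note: a worked sketch of the budget computed by _calculate_rest_token above; the history token budget is the model's context window minus the configured max_tokens and the tokens already consumed by the prompt, floored at zero:

    def rest_token_budget(context_size: int, max_tokens: int, current_prompt_tokens: int) -> int:
        return max(context_size - max_tokens - current_prompt_tokens, 0)

    rest_token_budget(200_000, 4_096, 1_250)   # -> 194654
    rest_token_budget(8_192, 8_192, 100)       # -> 0 (no room left for history)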

+ 4 - 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
-            system_query=query,
+            user_query=query,
             memory=memory,
             model_config=model_config,
-            files=files,
+            user_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
+            variable_pool=variable_pool,
+            jinja2_variables=[],
         )
 
         # handle invoke result

+ 16 - 1
api/poetry.lock

@@ -2423,6 +2423,21 @@ files = [
 [package.extras]
 test = ["pytest (>=6)"]
 
+[[package]]
+name = "faker"
+version = "32.1.0"
+description = "Faker is a Python package that generates fake data for you."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
+    {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
+]
+
+[package.dependencies]
+python-dateutil = ">=2.4"
+typing-extensions = "*"
+
 [[package]]
 name = "fal-client"
 version = "0.5.6"
@@ -11041,4 +11056,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69"
+content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639"

+ 1 - 0
api/pyproject.toml

@@ -268,6 +268,7 @@ weaviate-client = "~3.21.0"
 optional = true
 [tool.poetry.group.dev.dependencies]
 coverage = "~7.2.4"
+faker = "~32.1.0"
 pytest = "~8.3.2"
 pytest-benchmark = "~4.0.0"
 pytest-env = "~1.1.3"
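
Note: faker joins the dev group to generate throwaway test data for the new LLM-node unit tests. A hedged sketch of typical usage; the helper and its field names are illustrative, not the actual File schema:

    from faker import Faker

    fake = Faker()

    def make_remote_file_fixture() -> dict:
        return {
            "id": fake.uuid4(),
            "filename": fake.file_name(extension="pdf"),
            "remote_url": fake.url(),
        }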

+ 0 - 1
api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py

@@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
-from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
 
 
 @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

+ 3 - 11
api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py

@@ -4,29 +4,21 @@ import pytest
 
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
+from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
 
 
 def test_validate_credentials():
-    model = AzureAIStudioRerankModel()
+    model = AzureRerankModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
             model="azure-ai-studio-rerank-v1",
             credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
-            query="What is the capital of the United States?",
-            docs=[
-                "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
-                "Census, Carson City had a population of 55,274.",
-                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
-                "are a political division controlled by the United States. Its capital is Saipan.",
-            ],
-            score_threshold=0.8,
         )
 
 
 def test_invoke_model():
-    model = AzureAIStudioRerankModel()
+    model = AzureRerankModel()
 
     result = model.invoke(
         model="azure-ai-studio-rerank-v1",

+ 455 - 96
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py

@@ -1,125 +1,484 @@
+from collections.abc import Sequence
+from typing import Optional
+
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.file import File, FileTransferMethod, FileType
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
+from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.entities import (
+    ContextConfig,
+    LLMNodeChatModelMessage,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+    VisionConfigOptions,
+)
 from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
+from models.provider import ProviderType
 from models.workflow import WorkflowType
+from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
 
 
-class TestLLMNode:
-    @pytest.fixture
-    def llm_node(self):
-        data = LLMNodeData(
-            title="Test LLM",
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
-            prompt_template=[],
-            memory=None,
-            context=ContextConfig(enabled=False),
-            vision=VisionConfig(
-                enabled=True,
-                configs=VisionConfigOptions(
-                    variable_selector=["sys", "files"],
-                    detail=ImagePromptMessageContent.DETAIL.HIGH,
-                ),
-            ),
-        )
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-        )
-        node = LLMNode(
-            id="1",
-            config={
-                "id": "1",
-                "data": data.model_dump(),
-            },
-            graph_init_params=GraphInitParams(
-                tenant_id="1",
-                app_id="1",
-                workflow_type=WorkflowType.WORKFLOW,
-                workflow_id="1",
-                graph_config={},
-                user_id="1",
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.SERVICE_API,
-                call_depth=0,
+class MockTokenBufferMemory:
+    def __init__(self, history_messages=None):
+        self.history_messages = history_messages or []
+
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        if message_limit is not None:
+            return self.history_messages[-message_limit * 2 :]
+        return self.history_messages
+
+
+@pytest.fixture
+def llm_node():
+    data = LLMNodeData(
+        title="Test LLM",
+        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+        prompt_template=[],
+        memory=None,
+        context=ContextConfig(enabled=False),
+        vision=VisionConfig(
+            enabled=True,
+            configs=VisionConfigOptions(
+                variable_selector=["sys", "files"],
+                detail=ImagePromptMessageContent.DETAIL.HIGH,
             ),
-            graph=Graph(
-                root_node_id="1",
-                answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                    answer_dependencies={},
-                    answer_generate_route={},
-                ),
-                end_stream_param=EndStreamParam(
-                    end_dependencies={},
-                    end_stream_variable_selector_mapping={},
-                ),
+        ),
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = LLMNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
             ),
-            graph_runtime_state=GraphRuntimeState(
-                variable_pool=variable_pool,
-                start_at=0,
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
             ),
-        )
-        return node
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+@pytest.fixture
+def model_config():
+    # Create actual provider and model type instances
+    model_provider_factory = ModelProviderFactory()
+    provider_instance = model_provider_factory.get_provider_instance("openai")
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+    # Create a ProviderModelBundle
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id="1",
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(enabled=False),
+            custom_configuration=CustomConfiguration(provider=None),
+            model_settings=[],
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance,
+    )
 
-    def test_fetch_files_with_file_segment(self, llm_node):
-        file = File(
+    # Create and return a ModelConfigWithCredentialsEntity
+    return ModelConfigWithCredentialsEntity(
+        provider="openai",
+        model="gpt-3.5-turbo",
+        model_schema=AIModelEntity(
+            model="gpt-3.5-turbo",
+            label=I18nObject(en_US="GPT-3.5 Turbo"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={},
+        ),
+        mode="chat",
+        credentials={},
+        parameters={},
+        provider_model_bundle=provider_model_bundle,
+    )
+
+
+def test_fetch_files_with_file_segment(llm_node):
+    file = File(
+        id="1",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="test.jpg",
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        related_id="1",
+    )
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == [file]
+
+
+def test_fetch_files_with_array_file_segment(llm_node):
+    files = [
+        File(
             id="1",
             tenant_id="test",
             type=FileType.IMAGE,
-            filename="test.jpg",
+            filename="test1.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="1",
+        ),
+        File(
+            id="2",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test2.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="2",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == files
+
+
+def test_fetch_files_with_none_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_array_any_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_non_existent_variable(llm_node):
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
+    prompt_template = []
+    llm_node.node_data.prompt_template = prompt_template
+
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
         )
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
-
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == [file]
-
-    def test_fetch_files_with_array_file_segment(self, llm_node):
-        files = [
-            File(
-                id="1",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test1.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="1",
-            ),
-            File(
-                id="2",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test2.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="2",
-            ),
-        ]
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+    ]
+
+    fake_query = faker.sentence()
+
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=None,
+        memory=None,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=None,
+        vision_enabled=False,
+        vision_detail=fake_vision_detail,
+        variable_pool=llm_node.graph_runtime_state.variable_pool,
+        jinja2_variables=[],
+    )
+
+    assert prompt_messages == [UserPromptMessage(content=fake_query)]
+
+
+def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+    # Setup dify config
+    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+
+    # Generate fake values for prompt template
+    fake_assistant_prompt = faker.sentence()
+    fake_query = faker.sentence()
+    fake_context = faker.sentence()
+    fake_window_size = faker.random_int(min=1, max=3)
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+
+    # Setup mock memory with history messages
+    mock_history = [
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+    ]
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == files
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
+    )
 
-    def test_fetch_files_with_none_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+    memory = MockTokenBufferMemory(history_messages=mock_history)
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        LLMNodeTestScenario(
+            description="No files",
+            user_query=fake_query,
+            user_files=[],
+            features=[],
+            vision_enabled=False,
+            vision_detail=None,
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="User files",
+            user_query=fake_query,
+            user_files=[
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=False,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                UserPromptMessage(
+                    content=[
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File without vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File with video file and vision feature",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.VIDEO,
+                    filename="test1.mp4",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                    extension="mp4",
+                )
+            },
+        ),
+    ]
 
-    def test_fetch_files_with_array_any_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario.features
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+        for k, v in scenario.file_variables.items():
+            selector = k.split(".")
+            llm_node.graph_runtime_state.variable_pool.add(selector, v)
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=scenario.user_query,
+            user_files=scenario.user_files,
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario.prompt_template,
+            memory_config=memory_config,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
+            variable_pool=llm_node.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
 
-    def test_fetch_files_with_non_existent_variable(self, llm_node):
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+        # Verify the result
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
+        assert (
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"

+ 25 - 0
api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py

@@ -0,0 +1,25 @@
+from collections.abc import Mapping, Sequence
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
+
+
+class LLMNodeTestScenario(BaseModel):
+    """Test scenario for LLM node testing."""
+
+    description: str = Field(..., description="Description of the test scenario")
+    user_query: str = Field(..., description="User query input")
+    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
+    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
+    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
+    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
+    window_size: int = Field(..., description="Window size for memory")
+    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    file_variables: Mapping[str, File | Sequence[File]] = Field(
+        default_factory=dict, description="List of file variables"
+    )
+    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
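
The LLMNodeTestScenario model above makes it cheap to add further prompt-assembly cases to test_node.py. The sketch below shows a hypothetical extra scenario, mirroring the constructor shapes already used in this diff; every field value here is illustrative and not part of the commit.

    from core.file import File, FileTransferMethod, FileType
    from core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage
    from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
    from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

    # A file variable is referenced directly in the prompt, but the model lists no
    # vision feature, so only the plain user query is expected to reach the model
    # (memory history, if the harness enables it, would be prepended by the test).
    scenario = LLMNodeTestScenario(
        description="Image variable in prompt, model without vision feature",
        user_query="Describe the attached image.",
        user_files=[],
        vision_enabled=True,
        vision_detail=None,
        features=[],
        window_size=1,
        prompt_template=[
            LLMNodeChatModelMessage(
                text="{{#input.image#}}",
                role=PromptMessageRole.USER,
                edition_type="basic",
            ),
        ],
        file_variables={
            "input.image": File(
                tenant_id="test",
                type=FileType.IMAGE,
                filename="example.jpg",
                transfer_method=FileTransferMethod.REMOTE_URL,
                remote_url="https://example.com/example.jpg",
            ),
        },
        expected_messages=[UserPromptMessage(content="Describe the attached image.")],
    )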

+ 1 - 0
web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx

@@ -160,6 +160,7 @@ const CodeEditor: FC<Props> = ({
             hideSearch
             vars={availableVars}
             onChange={handleSelectVar}
+            isSupportFileVar={false}
           />
         </div>
       )}

+ 3 - 0
web/app/components/workflow/nodes/_base/components/variable/var-list.tsx

@@ -18,6 +18,7 @@ type Props = {
   isSupportConstantValue?: boolean
   onlyLeafNodeVar?: boolean
   filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean
+  isSupportFileVar?: boolean
 }
 
 const VarList: FC<Props> = ({
@@ -29,6 +30,7 @@ const VarList: FC<Props> = ({
   isSupportConstantValue,
   onlyLeafNodeVar,
   filterVar,
+  isSupportFileVar = true,
 }) => {
   const { t } = useTranslation()
 
@@ -94,6 +96,7 @@ const VarList: FC<Props> = ({
             defaultVarKindType={item.variable_type}
             onlyLeafNodeVar={onlyLeafNodeVar}
             filterVar={filterVar}
+            isSupportFileVar={isSupportFileVar}
           />
           {!readonly && (
             <RemoveButton

+ 3 - 0
web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx

@@ -59,6 +59,7 @@ type Props = {
   isInTable?: boolean
   onRemove?: () => void
   typePlaceHolder?: string
+  isSupportFileVar?: boolean
 }
 
 const VarReferencePicker: FC<Props> = ({
@@ -81,6 +82,7 @@ const VarReferencePicker: FC<Props> = ({
   isInTable,
   onRemove,
   typePlaceHolder,
+  isSupportFileVar = true,
 }) => {
   const { t } = useTranslation()
   const store = useStoreApi()
@@ -382,6 +384,7 @@ const VarReferencePicker: FC<Props> = ({
               vars={outputVars}
               onChange={handleVarReferenceChange}
               itemWidth={isAddBtnTrigger ? 260 : triggerWidth}
+              isSupportFileVar={isSupportFileVar}
             />
           )}
         </PortalToFollowElemContent>

+ 3 - 1
web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx

@@ -8,11 +8,13 @@ type Props = {
   vars: NodeOutPutVar[]
   onChange: (value: ValueSelector, varDetail: Var) => void
   itemWidth?: number
+  isSupportFileVar?: boolean
 }
 const VarReferencePopup: FC<Props> = ({
   vars,
   onChange,
   itemWidth,
+  isSupportFileVar = true,
 }) => {
   // max-h-[300px] overflow-y-auto todo: use portal to handle long list
   return (
@@ -24,7 +26,7 @@ const VarReferencePopup: FC<Props> = ({
         vars={vars}
         onChange={onChange}
         itemWidth={itemWidth}
-        isSupportFileVar
+        isSupportFileVar={isSupportFileVar}
       />
     </div >
   )

+ 1 - 0
web/app/components/workflow/nodes/code/panel.tsx

@@ -89,6 +89,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({
             list={inputs.variables}
             onChange={handleVarListChange}
             filterVar={filterVar}
+            isSupportFileVar={false}
           />
         </Field>
         <Split />

+ 1 - 0
web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx

@@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
       onEditionTypeChange={onEditionTypeChange}
       varList={varList}
       handleAddVariable={handleAddVariable}
+      isSupportFileVar
     />
   )
 }

+ 4 - 1
web/app/components/workflow/nodes/llm/panel.tsx

@@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
     handleStop,
     varInputs,
     runResult,
+    filterJinjia2InputVar,
   } = useConfig(id, data)
 
   const model = inputs.model
@@ -194,7 +195,8 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
               list={inputs.prompt_config?.jinja2_variables || []}
               onChange={handleVarListChange}
               onVarNameChange={handleVarNameChange}
-              filterVar={filterVar}
+              filterVar={filterJinjia2InputVar}
+              isSupportFileVar={false}
             />
           </Field>
         )}
@@ -233,6 +235,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
                 hasSetBlockStatus={hasSetBlockStatus}
                 nodesOutputVars={availableVars}
                 availableNodes={availableNodesWithParent}
+                isSupportFileVar
               />
 
               {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (

+ 6 - 1
web/app/components/workflow/nodes/llm/use-config.ts

@@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
   }, [inputs, setInputs])
 
   const filterInputVar = useCallback((varPayload: Var) => {
+    return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
+  }, [])
+
+  const filterJinjia2InputVar = useCallback((varPayload: Var) => {
     return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
   }, [])
 
   const filterMemoryPromptVar = useCallback((varPayload: Var) => {
-    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
+    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
   }, [])
 
   const {
@@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     handleRun,
     handleStop,
     runResult,
+    filterJinjia2InputVar,
   }
 }
 

+ 1 - 0
web/app/components/workflow/nodes/template-transform/panel.tsx

@@ -64,6 +64,7 @@ const Panel: FC<NodePanelProps<TemplateTransformNodeType>> = ({
             onChange={handleVarListChange}
             onVarNameChange={handleVarNameChange}
             filterVar={filterVar}
+            isSupportFileVar={false}
           />
         </Field>
         <Split />