Prechádzať zdrojové kódy

chore: the consistency of MultiModalPromptMessageContent (#11721)

非法操作 4 mesiacov pred
rodič
commit
c9b4029ce7

+ 1 - 2
api/.env.example

@@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model configuration
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 

+ 4 - 9
api/configs/feature/__init__.py

@@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class VisionFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
-        default="base64",
-    )
-
-    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+class MultiModalTransferConfig(BaseSettings):
+    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 
@@ -778,13 +773,13 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
+    MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
     SecurityConfig,

+ 25 - 32
api/core/file/file_manager.py

@@ -42,33 +42,31 @@ def to_prompt_message_content(
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    match f.type:
-        case FileType.IMAGE:
-            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
-            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-
-            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
-        case FileType.AUDIO:
-            data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
-        case FileType.VIDEO:
-            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
-                data = _to_url(f)
-            else:
-                data = _to_base64_data_string(f)
-            if f.extension is None:
-                raise ValueError("Missing file extension")
-            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
-        case FileType.DOCUMENT:
-            data = _to_base64_data_string(f)
-            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
+    if f.extension is None:
+        raise ValueError("Missing file extension")
+    if f.mime_type is None:
+        raise ValueError("Missing file mime_type")
+
+    params = {
+        "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
+        "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
+        "format": f.extension.removeprefix("."),
+        "mime_type": f.mime_type,
+    }
+    if f.type == FileType.IMAGE:
+        params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
+
+    prompt_class_map = {
+        FileType.IMAGE: ImagePromptMessageContent,
+        FileType.AUDIO: AudioPromptMessageContent,
+        FileType.VIDEO: VideoPromptMessageContent,
+        FileType.DOCUMENT: DocumentPromptMessageContent,
+    }
+
+    try:
+        return prompt_class_map[f.type](**params)
+    except KeyError:
+        raise ValueError(f"file type {f.type} is not supported")
 
 
 def download(f: File, /):
@@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
     return encoded_string
 
 
-def _to_base64_data_string(f: File, /):
-    encoded_string = _get_encoded_string(f)
-    return f"data:{f.mime_type};base64,{encoded_string}"
-
-
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:

+ 24 - 15
api/core/model_runtime/entities/message_entities.py

@@ -1,9 +1,9 @@
 from abc import ABC
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Literal, Optional
+from typing import Optional
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, computed_field, field_validator
 
 
 class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
     """
 
     type: PromptMessageContentType
-    data: str
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
     """
 
     type: PromptMessageContentType = PromptMessageContentType.TEXT
+    data: str
+
+
+class MultiModalPromptMessageContent(PromptMessageContent):
+    """
+    Model class for multi-modal prompt message content.
+    """
+
+    type: PromptMessageContentType
+    format: str = Field(..., description="the format of multi-modal file")
+    base64_data: str = Field("", description="the base64 data of multi-modal file")
+    url: str = Field("", description="the url of multi-modal file")
+    mime_type: str = Field(..., description="the mime type of multi-modal file")
 
+    @computed_field(return_type=str)
+    @property
+    def data(self):
+        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
 
-class VideoPromptMessageContent(PromptMessageContent):
+
+class VideoPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.VIDEO
-    data: str = Field(..., description="Base64 encoded video data")
-    format: str = Field(..., description="Video format")
 
 
-class AudioPromptMessageContent(PromptMessageContent):
+class AudioPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
-    data: str = Field(..., description="Base64 encoded audio data")
-    format: str = Field(..., description="Audio format")
 
 
-class ImagePromptMessageContent(PromptMessageContent):
+class ImagePromptMessageContent(MultiModalPromptMessageContent):
     """
     Model class for image prompt message content.
     """
@@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):
 
     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
-    format: str = Field("jpg", description="Image format")
 
 
-class DocumentPromptMessageContent(PromptMessageContent):
+class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
-    encode_format: Literal["base64"]
-    data: str
-    format: str = Field(..., description="Document format")
 
 
 class PromptMessage(ABC, BaseModel):

+ 10 - 17
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -1,5 +1,4 @@
 import base64
-import io
 import json
 from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
 )
 from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
-from PIL import Image
 
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                 sub_messages.append(sub_message_dict)
                             elif message_content.type == PromptMessageContentType.IMAGE:
                                 message_content = cast(ImagePromptMessageContent, message_content)
-                                if not message_content.data.startswith("data:"):
+                                if not message_content.base64_data:
                                     # fetch image data from url
                                     try:
-                                        image_content = requests.get(message_content.data).content
-                                        with Image.open(io.BytesIO(image_content)) as img:
-                                            mime_type = f"image/{img.format.lower()}"
+                                        image_content = requests.get(message_content.url).content
                                         base64_data = base64.b64encode(image_content).decode("utf-8")
                                     except Exception as ex:
                                         raise ValueError(
                                             f"Failed to fetch image data from url {message_content.data}, {ex}"
                                         )
                                 else:
-                                    data_split = message_content.data.split(";base64,")
-                                    mime_type = data_split[0].replace("data:", "")
-                                    base64_data = data_split[1]
+                                    base64_data = message_content.base64_data
 
+                                mime_type = message_content.mime_type
                                 if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                                     raise ValueError(
                                         f"Unsupported image type {mime_type}, "
@@ -526,19 +521,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                 }
                                 sub_messages.append(sub_message_dict)
                             elif isinstance(message_content, DocumentPromptMessageContent):
-                                data_split = message_content.data.split(";base64,")
-                                mime_type = data_split[0].replace("data:", "")
-                                base64_data = data_split[1]
-                                if mime_type != "application/pdf":
+                                if message_content.mime_type != "application/pdf":
                                     raise ValueError(
-                                        f"Unsupported document type {mime_type}, " "only support application/pdf"
+                                        f"Unsupported document type {message_content.mime_type}, "
+                                        "only support application/pdf"
                                     )
                                 sub_message_dict = {
                                     "type": "document",
                                     "source": {
-                                        "type": message_content.encode_format,
-                                        "media_type": mime_type,
-                                        "data": base64_data,
+                                        "type": "base64",
+                                        "media_type": message_content.mime_type,
+                                        "data": message_content.data,
                                     },
                                 }
                                 sub_messages.append(sub_message_dict)

+ 3 - 3
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
                             sub_messages.append(sub_message_dict)
                         elif message_content.type == PromptMessageContentType.VIDEO:
                             message_content = cast(VideoPromptMessageContent, message_content)
-                            video_url = message_content.data
-                            if message_content.data.startswith("data:"):
-                                raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
+                            video_url = message_content.url
+                            if not video_url:
+                                raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")
 
                             sub_message_dict = {"video": video_url}
                             sub_messages.append(sub_message_dict)

Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 0 - 0
api/tests/integration_tests/model_runtime/azure_openai/test_llm.py


Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 0 - 0
api/tests/integration_tests/model_runtime/google/test_llm.py


Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 0 - 0
api/tests/integration_tests/model_runtime/ollama/test_llm.py


Rozdielové dáta súboru neboli zobrazené, pretože súbor je príliš veľký
+ 0 - 0
api/tests/integration_tests/model_runtime/openai/test_llm.py


+ 5 - 1
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from configs import dify_config
 from core.app.app_config.entities import ModelConfigEntity
 from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
 
 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     model_config_mock, _, messages, inputs, context = get_chat_model_args
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"
 
     files = [
         File(
@@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
-        mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
+        mock_get_encoded_string.return_value = ImagePromptMessageContent(
+            url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
+        )
         prompt_messages = prompt_transform._get_chat_model_prompt_messages(
             prompt_template=messages,
             inputs=inputs,

+ 12 - 8
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py

@@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
-from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
 
 def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
     # Setup dify config
-    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
-    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_FORMAT = "url"
 
     # Generate fake values for prompt template
     fake_assistant_prompt = faker.sentence()
@@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
-                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
+                    extension=".jpg",
+                    mime_type="image/jpg",
                 )
             ],
             vision_enabled=True,
@@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 UserPromptMessage(
                     content=[
                         TextPromptMessageContent(data=fake_query),
-                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                        ImagePromptMessageContent(
+                            url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                        ),
                     ]
                 ),
             ],
@@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
             expected_messages=[
                 UserPromptMessage(
                     content=[
-                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                        ImagePromptMessageContent(
+                            url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
+                        ),
                     ]
                 ),
             ]
@@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
-                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
+                    extension=".jpg",
+                    mime_type="image/jpg",
                 )
             },
         ),

+ 3 - 4
docker/.env.example

@@ -614,13 +614,12 @@ CODE_GENERATION_MAX_TOKENS=1024
 # Multi-modal Configuration
 # ------------------------------
 
-# The format of the image/video sent when the multi-modal model is input,
+# The format of the image/video/audio/document sent when the multi-modal model is input,
 # the default is base64, optional url.
 # The delay of the call in url mode will be lower than that in base64 mode.
 # It is generally recommended to use the more compatible base64 mode.
-# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
+# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
+MULTIMODAL_SEND_FORMAT=base64
 
 # Upload image file size limit, default 10M.
 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10

+ 1 - 2
docker/docker-compose.yaml

@@ -225,8 +225,7 @@ x-shared-env: &shared-api-worker-env
   UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
   PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
   CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
-  MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
-  MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
+  MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
   UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
   UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
   UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}

Niektoré súbory nie sú zobrazené, pretože je v týchto rozdielových dátach zmenené mnoho súborov