Browse Source

feat: enhance gemini models (#11497)

非法操作 4 months ago
commit
74fdc16bd1
23 changed files with 138 additions and 113 deletions
  1. api/core/file/file_manager.py (+5 -11)
  2. api/core/model_runtime/entities/message_entities.py (+2 -1)
  3. api/core/model_runtime/model_providers/anthropic/llm/llm.py (+7 -5)
  4. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml (+2 -0)
  5. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml (+2 -0)
  6. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml (+2 -0)
  7. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml (+2 -0)
  8. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml (+2 -0)
  9. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml (+2 -0)
  10. api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml (+2 -0)
  11. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml (+2 -0)
  12. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml (+2 -0)
  13. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml (+2 -0)
  14. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml (+2 -0)
  15. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml (+2 -0)
  16. api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml (+2 -0)
  17. api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml (+2 -0)
  18. api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml (+3 -0)
  19. api/core/model_runtime/model_providers/google/llm/llm.py (+56 -63)
  20. api/core/model_runtime/model_providers/openai/llm/llm.py (+3 -1)
  21. api/tests/integration_tests/model_runtime/__mock/google.py (+29 -29)
  22. api/tests/integration_tests/model_runtime/google/test_llm.py (+3 -3)
  23. api/tests/unit_tests/core/workflow/nodes/llm/test_node.py (+2 -0)

+ 5 - 11
api/core/file/file_manager.py

@@ -50,12 +50,12 @@ def to_prompt_message_content(
             else:
                 data = _to_base64_data_string(f)
 
-            return ImagePromptMessageContent(data=data, detail=image_detail_config)
+            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
         case FileType.AUDIO:
-            encoded_string = _get_encoded_string(f)
+            data = _to_base64_data_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.VIDEO:
             if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
                 data = _to_url(f)
@@ -65,14 +65,8 @@ def to_prompt_message_content(
                 raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.DOCUMENT:
-            data = _get_encoded_string(f)
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-            return DocumentPromptMessageContent(
-                encode_format="base64",
-                mime_type=f.mime_type,
-                data=data,
-            )
+            data = _to_base64_data_string(f)
+            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")
 

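The image, audio, and document branches now all funnel through `_to_base64_data_string`, so every base64 payload travels as a single `data:` URI that carries its own MIME type. A minimal sketch of what that helper presumably produces, assuming it pairs the file's MIME type with the base64 payload (the Anthropic and OpenAI diffs below split on exactly this `";base64,"` separator):

```python
import base64


def _to_base64_data_string_sketch(raw_bytes: bytes, mime_type: str) -> str:
    # Hypothetical stand-in for file_manager's _to_base64_data_string:
    # wrap the encoded bytes in a standard data URI so downstream
    # providers can recover both the MIME type and the payload.
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


# _to_base64_data_string_sketch(b"%PDF-1.4 ...", "application/pdf")
# -> "data:application/pdf;base64,JVBERi0xLjQgLi4u"
```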
+ 2 - 1
api/core/model_runtime/entities/message_entities.py

@@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):
 
     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
+    format: str = Field("jpg", description="Image format")
 
 
 class DocumentPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
     encode_format: Literal["base64"]
-    mime_type: str
     data: str
+    format: str = Field(..., description="Document format")
 
 
 class PromptMessage(ABC, BaseModel):

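Constructed with illustrative values, the updated models look like this; note that `mime_type` is gone from `DocumentPromptMessageContent` and both classes now carry a `format`:

```python
from core.model_runtime.entities.message_entities import (
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
)

image = ImagePromptMessageContent(
    data="data:image/png;base64,iVBORw0KGgo=",  # example payload
    format="png",  # new field; defaults to "jpg" when omitted
)
document = DocumentPromptMessageContent(
    encode_format="base64",
    data="data:application/pdf;base64,JVBERi0xLjQ=",  # example payload
    format="pdf",  # new required field, replacing the removed mime_type
)
```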
+ 7 - 5
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                                 }
                                 sub_messages.append(sub_message_dict)
                             elif isinstance(message_content, DocumentPromptMessageContent):
-                                if message_content.mime_type != "application/pdf":
+                                data_split = message_content.data.split(";base64,")
+                                mime_type = data_split[0].replace("data:", "")
+                                base64_data = data_split[1]
+                                if mime_type != "application/pdf":
                                     raise ValueError(
-                                        f"Unsupported document type {message_content.mime_type}, "
-                                        "only support application/pdf"
+                                        f"Unsupported document type {mime_type}, " "only support application/pdf"
                                     )
                                 sub_message_dict = {
                                     "type": "document",
                                     "source": {
                                         "type": message_content.encode_format,
-                                        "media_type": message_content.mime_type,
-                                        "data": message_content.data,
+                                        "media_type": mime_type,
+                                        "data": base64_data,
                                     },
                                 }
                                 sub_messages.append(sub_message_dict)

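With `mime_type` removed from the entity, the Anthropic handler now recovers it from the data URI itself. The split logic in isolation:

```python
# Unpacking a document payload produced by file_manager.
data = "data:application/pdf;base64,JVBERi0xLjQ="  # example payload

header, base64_data = data.split(";base64,")
mime_type = header.replace("data:", "")

print(mime_type)    # "application/pdf" -> accepted
print(base64_data)  # "JVBERi0xLjQ="    -> sent as source.data
# Any other mime_type raises ValueError, since this handler only
# supports application/pdf documents.
```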
+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

+ 2 - 0
api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

+ 3 - 0
api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

+ 56 - 63
api/core/model_runtime/model_providers/google/llm/llm.py

@@ -1,29 +1,30 @@
 import base64
-import io
 import json
+import os
+import tempfile
+import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import requests
 from google.api_core import exceptions
-from google.generativeai.client import _ClientManager
-from google.generativeai.types import ContentType, GenerateContentResponse
+from google.generativeai.types import ContentType, File, GenerateContentResponse
 from google.generativeai.types.content_types import to_part
-from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-GOOGLE_AVAILABLE_MIMETYPE = [
-    "application/pdf",
-    "application/x-javascript",
-    "text/javascript",
-    "application/x-python",
-    "text/x-python",
-    "text/plain",
-    "text/html",
-    "text/css",
-    "text/md",
-    "text/csv",
-    "text/xml",
-    "text/rtf",
-]
+from extensions.ext_redis import redis_client
 
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         if stop:
             config_kwargs["stop_sequences"] = stop
 
+        genai.configure(api_key=credentials["google_api_key"])
         google_model = genai.GenerativeModel(model_name=model)
 
         history = []
 
-        # hack for gemini-pro-vision, which currently does not support multi-turn chat
-        if model == "gemini-pro-vision":
-            last_msg = prompt_messages[-1]
-            content = self._format_message_to_glm_content(last_msg)
-            history.append(content)
-        else:
-            for msg in prompt_messages:  # makes message roles strictly alternating
-                content = self._format_message_to_glm_content(msg)
-                if history and history[-1]["role"] == content["role"]:
-                    history[-1]["parts"].extend(content["parts"])
-                else:
-                    history.append(content)
-
-        # Create a new ClientManager with tenant's API key
-        new_client_manager = _ClientManager()
-        new_client_manager.configure(api_key=credentials["google_api_key"])
-        new_custom_client = new_client_manager.make_client("generative")
-
-        google_model._client = new_custom_client
+        for msg in prompt_messages:  # makes message roles strictly alternating
+            content = self._format_message_to_glm_content(msg)
+            if history and history[-1]["role"] == content["role"]:
+                history[-1]["parts"].extend(content["parts"])
+            else:
+                history.append(content)
 
         response = google_model.generate_content(
             contents=history,
@@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         content = message.content
         if isinstance(content, list):
-            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
+            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
 
         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"
@@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
+    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
+        key = f"{message_content.type.value}:{hash(message_content.data)}"
+        if redis_client.exists(key):
+            try:
+                return genai.get_file(redis_client.get(key).decode())
+            except:
+                pass
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            if message_content.data.startswith("data:"):
+                metadata, base64_data = message_content.data.split(",", 1)
+                file_content = base64.b64decode(base64_data)
+                mime_type = metadata.split(";", 1)[0].split(":")[1]
+                temp_file.write(file_content)
+            else:
+                # only ImagePromptMessageContent and VideoPromptMessageContent has url
+                try:
+                    response = requests.get(message_content.data)
+                    response.raise_for_status()
+                    if message_content.type is ImagePromptMessageContent:
+                        prefix = "image/"
+                    elif message_content.type is VideoPromptMessageContent:
+                        prefix = "video/"
+                    mime_type = prefix + message_content.format
+                    temp_file.write(response.content)
+                except Exception as ex:
+                    raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}")
+            temp_file.flush()
+        try:
+            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
+            while file.state.name == "PROCESSING":
+                time.sleep(5)
+                file = genai.get_file(file.name)
+            # google will delete your upload files in 2 days.
+            redis_client.setex(key, 47 * 60 * 60, file.name)
+            return file
+        finally:
+            os.unlink(temp_file.name)
+
     def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
         """
         Format a single message into glm.Content for Google API
@@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 for c in message.content:
                     if c.type == PromptMessageContentType.TEXT:
                         glm_content["parts"].append(to_part(c.data))
-                    elif c.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, c)
-                        if message_content.data.startswith("data:"):
-                            metadata, base64_data = c.data.split(",", 1)
-                            mime_type = metadata.split(";", 1)[0].split(":")[1]
-                        else:
-                            # fetch image data from url
-                            try:
-                                image_content = requests.get(message_content.data).content
-                                with Image.open(io.BytesIO(image_content)) as img:
-                                    mime_type = f"image/{img.format.lower()}"
-                                base64_data = base64.b64encode(image_content).decode("utf-8")
-                            except Exception as ex:
-                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                        blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
-                        glm_content["parts"].append(blob)
-                    elif c.type == PromptMessageContentType.DOCUMENT:
-                        message_content = cast(DocumentPromptMessageContent, c)
-                        if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
-                            raise ValueError(f"Unsupported mime type {message_content.mime_type}")
-                        blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
-                        glm_content["parts"].append(blob)
+                    else:
+                        glm_content["parts"].append(self._upload_file_content_to_google(c))
 
             return glm_content
         elif isinstance(message, AssistantPromptMessage):

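This is the biggest behavioral change: instead of inlining blobs into `glm.Content` parts, the provider uploads media through the Gemini Files API and memoizes the returned file name in Redis (Google deletes uploads after about two days, hence the 47-hour TTL). A condensed sketch of the cache-then-upload flow, with the temp-file and data-URI handling elided and `get_or_upload` as an illustrative name:

```python
import time

import google.generativeai as genai


def get_or_upload(key: str, path: str, mime_type: str, redis_client):
    # Reuse a previous upload if its server-side name is still cached.
    cached = redis_client.get(key)
    if cached:
        try:
            return genai.get_file(cached.decode())
        except Exception:
            pass  # the upload may have expired; fall through and re-upload

    file = genai.upload_file(path=path, mime_type=mime_type)
    while file.state.name == "PROCESSING":  # poll until processing finishes
        time.sleep(5)
        file = genai.get_file(file.name)

    # Cache just under the ~48h server-side retention window.
    redis_client.setex(key, 47 * 60 * 60, file.name)
    return file
```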
+ 3 - 1
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         }
                         sub_messages.append(sub_message_dict)
                     elif isinstance(message_content, AudioPromptMessageContent):
+                        data_split = message_content.data.split(";base64,")
+                        base64_data = data_split[1]
                         sub_message_dict = {
                             "type": "input_audio",
                             "input_audio": {
-                                "data": message_content.data,
+                                "data": base64_data,
                                 "format": message_content.format,
                             },
                         }

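Audio content now arrives as a full data URI, so the OpenAI handler strips the header before building the `input_audio` part; the Chat Completions API expects bare base64 plus a separate `format` field. What reaches the API after the split:

```python
data = "data:audio/mp3;base64,SUQzBAAAAAAAAA=="  # example payload

base64_data = data.split(";base64,")[1]
sub_message = {
    "type": "input_audio",
    "input_audio": {"data": base64_data, "format": "mp3"},
}
```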
+ 29 - 29
api/tests/integration_tests/model_runtime/__mock/google.py

@@ -1,4 +1,5 @@
 from collections.abc import Generator
+from unittest.mock import MagicMock
 
 import google.generativeai.types.generation_types as generation_config_types
 import pytest
@@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
 from google.ai import generativelanguage as glm
 from google.ai.generativelanguage_v1beta.types import content as gag_content
 from google.generativeai import GenerativeModel
-from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse, content_types, safety_types
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
 
-current_api_key = ""
+from extensions import ext_redis
 
 
 class MockGoogleResponseClass:
@@ -57,11 +57,6 @@ class MockGoogleClass:
         stream: bool = False,
         **kwargs,
     ) -> GenerateContentResponse:
-        global current_api_key
-
-        if len(current_api_key) < 16:
-            raise Exception("Invalid API key")
-
         if stream:
             return MockGoogleClass.generate_content_stream()
 
@@ -75,33 +70,29 @@ class MockGoogleClass:
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]
 
-    def make_client(self: _ClientManager, name: str):
-        global current_api_key
 
-        if name.endswith("_async"):
-            name = name.split("_")[0]
-            cls = getattr(glm, name.title() + "ServiceAsyncClient")
-        else:
-            cls = getattr(glm, name.title() + "ServiceClient")
+def mock_configure(api_key: str):
+    if len(api_key) < 16:
+        raise Exception("Invalid API key")
+
+
+class MockFileState:
+    def __init__(self):
+        self.name = "FINISHED"
 
-        # Attempt to configure using defaults.
-        if not self.client_config:
-            configure()
 
-        client_options = self.client_config.get("client_options", None)
-        if client_options:
-            current_api_key = client_options.api_key
+class MockGoogleFile:
+    def __init__(self, name: str = "mock_file_name"):
+        self.name = name
+        self.state = MockFileState()
 
-        def nop(self, *args, **kwargs):
-            pass
 
-        original_init = cls.__init__
-        cls.__init__ = nop
-        client: glm.GenerativeServiceClient = cls(**self.client_config)
-        cls.__init__ = original_init
+def mock_get_file(name: str) -> MockGoogleFile:
+    return MockGoogleFile(name)
 
-        if not self.default_metadata:
-            return client
+
+def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
+    return MockGoogleFile()
 
 
 @pytest.fixture
@@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
     monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
     monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
-    monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
+    monkeypatch.setattr("google.generativeai.configure", mock_configure)
+    monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
+    monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
 
     yield
 
     monkeypatch.undo()
+
+
+@pytest.fixture
+def setup_mock_redis() -> None:
+    ext_redis.redis_client.get = MagicMock(return_value=None)
+    ext_redis.redis_client.setex = MagicMock(return_value=None)
+    ext_redis.redis_client.exists = MagicMock(return_value=True)

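The rewritten mock no longer reaches into `_ClientManager`; it patches `genai.configure`, `genai.get_file`, and `genai.upload_file` directly, and `setup_mock_redis` stubs out the upload cache. One subtlety: `exists` returns True while `get` returns None, so the cache-hit branch in `_upload_file_content_to_google` always fails and falls through to the mocked upload:

```python
# What the mocked cache lookup does inside _upload_file_content_to_google:
cached = None  # redis_client.get(key) is MagicMock(return_value=None)
try:
    name = cached.decode()  # AttributeError: None has no attribute 'decode'
except Exception:
    pass  # swallowed by the bare except; proceeds to mock_upload_file
```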
+ 3 - 3
api/tests/integration_tests/model_runtime/google/test_llm.py

@@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
-from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
+from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
@@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision(setup_google_mock):
+def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
@@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
 
 
 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
+def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(

+ 2 - 0
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py

@@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
+                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
                 )
@@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                     tenant_id="test",
                     type=FileType.IMAGE,
                     filename="test1.jpg",
+                    extension=".jpg",
                     transfer_method=FileTransferMethod.REMOTE_URL,
                     remote_url=fake_remote_url,
                 )
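Finally, the test fixtures gain `extension=".jpg"` because the image branch in `file_manager.py` (first diff above) now derives the prompt content's `format` from it; a `File` without an extension would raise AttributeError on `f.extension.lstrip(".")`. In miniature:

```python
extension = ".jpg"
image_format = extension.lstrip(".")  # "jpg", stored on ImagePromptMessageContent
```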