|
@@ -1,29 +1,30 @@
|
|
|
import base64
|
|
|
-import io
|
|
|
+import hashlib
import json
|
|
|
+import os
|
|
|
+import tempfile
|
|
|
+import time
|
|
|
from collections.abc import Generator
|
|
|
-from typing import Optional, Union, cast
|
|
|
+from typing import Optional, Union
|
|
|
|
|
|
import google.ai.generativelanguage as glm
|
|
|
import google.generativeai as genai
|
|
|
import requests
|
|
|
from google.api_core import exceptions
|
|
|
-from google.generativeai.client import _ClientManager
|
|
|
-from google.generativeai.types import ContentType, GenerateContentResponse
|
|
|
+from google.generativeai.types import ContentType, File, GenerateContentResponse
|
|
|
from google.generativeai.types.content_types import to_part
|
|
|
-from PIL import Image
|
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
AssistantPromptMessage,
|
|
|
- DocumentPromptMessageContent,
|
|
|
ImagePromptMessageContent,
|
|
|
PromptMessage,
|
|
|
+ PromptMessageContent,
|
|
|
PromptMessageContentType,
|
|
|
PromptMessageTool,
|
|
|
SystemPromptMessage,
|
|
|
ToolPromptMessage,
|
|
|
UserPromptMessage,
|
|
|
+ VideoPromptMessageContent,
|
|
|
)
|
|
|
from core.model_runtime.errors.invoke import (
|
|
|
InvokeAuthorizationError,
|
|
@@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
|
|
|
)
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
-
|
|
|
-GOOGLE_AVAILABLE_MIMETYPE = [
|
|
|
- "application/pdf",
|
|
|
- "application/x-javascript",
|
|
|
- "text/javascript",
|
|
|
- "application/x-python",
|
|
|
- "text/x-python",
|
|
|
- "text/plain",
|
|
|
- "text/html",
|
|
|
- "text/css",
|
|
|
- "text/md",
|
|
|
- "text/csv",
|
|
|
- "text/xml",
|
|
|
- "text/rtf",
|
|
|
-]
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
|
|
|
|
|
|
class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
@@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
if stop:
|
|
|
config_kwargs["stop_sequences"] = stop
|
|
|
|
|
|
+ genai.configure(api_key=credentials["google_api_key"])
|
|
|
google_model = genai.GenerativeModel(model_name=model)
|
|
|
|
|
|
history = []
|
|
|
|
|
|
- # hack for gemini-pro-vision, which currently does not support multi-turn chat
|
|
|
- if model == "gemini-pro-vision":
|
|
|
- last_msg = prompt_messages[-1]
|
|
|
- content = self._format_message_to_glm_content(last_msg)
|
|
|
- history.append(content)
|
|
|
- else:
|
|
|
- for msg in prompt_messages: # makes message roles strictly alternating
|
|
|
- content = self._format_message_to_glm_content(msg)
|
|
|
- if history and history[-1]["role"] == content["role"]:
|
|
|
- history[-1]["parts"].extend(content["parts"])
|
|
|
- else:
|
|
|
- history.append(content)
|
|
|
-
|
|
|
- # Create a new ClientManager with tenant's API key
|
|
|
- new_client_manager = _ClientManager()
|
|
|
- new_client_manager.configure(api_key=credentials["google_api_key"])
|
|
|
- new_custom_client = new_client_manager.make_client("generative")
|
|
|
-
|
|
|
- google_model._client = new_custom_client
|
|
|
+ for msg in prompt_messages: # makes message roles strictly alternating
|
|
|
+ content = self._format_message_to_glm_content(msg)
|
|
|
+ if history and history[-1]["role"] == content["role"]:
|
|
|
+ history[-1]["parts"].extend(content["parts"])
|
|
|
+ else:
|
|
|
+ history.append(content)
|
|
|
|
|
|
response = google_model.generate_content(
|
|
|
contents=history,
|
|
@@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
content = message.content
|
|
|
if isinstance(content, list):
|
|
|
- content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
|
|
+ content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
|
|
|
|
|
if isinstance(message, UserPromptMessage):
|
|
|
message_text = f"{human_prompt} {content}"
|
|
@@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
return message_text
|
|
|
|
|
|
+ def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
|
|
|
+        # sha256 instead of built-in hash(): hash() of a str is salted per process
+        # (PYTHONHASHSEED), so the Redis cache key would never match across workers.
+        key = f"{message_content.type.value}:{hashlib.sha256(message_content.data.encode()).hexdigest()}"
|
|
|
+        # Single GET avoids the exists()/get() race when the key expires in between.
+        cached_file_name = redis_client.get(key)
+        if cached_file_name:
+            try:
+                return genai.get_file(cached_file_name.decode())
+            except Exception:
+                # Cached file may already be deleted on Google's side; re-upload below.
+                pass
|
|
|
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
|
|
+ if message_content.data.startswith("data:"):
|
|
|
+ metadata, base64_data = message_content.data.split(",", 1)
|
|
|
+ file_content = base64.b64decode(base64_data)
|
|
|
+ mime_type = metadata.split(";", 1)[0].split(":")[1]
|
|
|
+ temp_file.write(file_content)
|
|
|
+ else:
|
|
|
+ # only ImagePromptMessageContent and VideoPromptMessageContent has url
|
|
|
+ try:
|
|
|
+ response = requests.get(message_content.data)
|
|
|
+ response.raise_for_status()
|
|
|
+                    # Compare enum members, not classes: `message_content.type` is a
+                    # PromptMessageContentType value, so `is ImagePromptMessageContent`
+                    # was always False and left `prefix` unbound.
+                    if message_content.type == PromptMessageContentType.IMAGE:
+                        prefix = "image/"
+                    elif message_content.type == PromptMessageContentType.VIDEO:
+                        prefix = "video/"
+                    else:
+                        raise ValueError(f"Unsupported content type {message_content.type}")
|
|
|
+ mime_type = prefix + message_content.format
|
|
|
+ temp_file.write(response.content)
|
|
|
+ except Exception as ex:
|
|
|
+                    raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}") from ex
|
|
|
+ temp_file.flush()
|
|
|
+ try:
|
|
|
+            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
+            # Poll until Google finishes server-side processing of the upload.
+            while file.state.name == "PROCESSING":
+                time.sleep(5)
+                file = genai.get_file(file.name)
+            if file.state.name == "FAILED":
+                raise ValueError(f"Google file processing failed: {file.name}")
|
|
|
+ # google will delete your upload files in 2 days.
|
|
|
+ redis_client.setex(key, 47 * 60 * 60, file.name)
|
|
|
+ return file
|
|
|
+ finally:
|
|
|
+ os.unlink(temp_file.name)
|
|
|
+
|
|
|
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
|
|
"""
|
|
|
Format a single message into glm.Content for Google API
|
|
@@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
|
for c in message.content:
|
|
|
if c.type == PromptMessageContentType.TEXT:
|
|
|
glm_content["parts"].append(to_part(c.data))
|
|
|
- elif c.type == PromptMessageContentType.IMAGE:
|
|
|
- message_content = cast(ImagePromptMessageContent, c)
|
|
|
- if message_content.data.startswith("data:"):
|
|
|
- metadata, base64_data = c.data.split(",", 1)
|
|
|
- mime_type = metadata.split(";", 1)[0].split(":")[1]
|
|
|
- else:
|
|
|
- # fetch image data from url
|
|
|
- try:
|
|
|
- image_content = requests.get(message_content.data).content
|
|
|
- with Image.open(io.BytesIO(image_content)) as img:
|
|
|
- mime_type = f"image/{img.format.lower()}"
|
|
|
- base64_data = base64.b64encode(image_content).decode("utf-8")
|
|
|
- except Exception as ex:
|
|
|
- raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
|
|
- blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
|
|
- glm_content["parts"].append(blob)
|
|
|
- elif c.type == PromptMessageContentType.DOCUMENT:
|
|
|
- message_content = cast(DocumentPromptMessageContent, c)
|
|
|
- if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
|
|
|
- raise ValueError(f"Unsupported mime type {message_content.mime_type}")
|
|
|
- blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
|
|
|
- glm_content["parts"].append(blob)
|
|
|
+ else:
|
|
|
+ glm_content["parts"].append(self._upload_file_content_to_google(c))
|
|
|
|
|
|
return glm_content
|
|
|
elif isinstance(message, AssistantPromptMessage):
|