Ver código fonte

Support for Vertex AI (#4586)

Patryk Garstecki 11 meses atrás
pai
commit
296887754f

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -2,6 +2,7 @@
 - anthropic
 - azure_openai
 - google
+- vertex_ai
 - nvidia
 - cohere
 - bedrock

+ 0 - 0
api/core/model_runtime/model_providers/vertex_ai/__init__.py


BIN
api/core/model_runtime/model_providers/vertex_ai/_assets/icon_l_en.png


+ 1 - 0
api/core/model_runtime/model_providers/vertex_ai/_assets/icon_s_en.svg

@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="24px" height="24px"><path d="M20,13.89A.77.77,0,0,0,19,13.73l-7,5.14v.22a.72.72,0,1,1,0,1.43v0a.74.74,0,0,0,.45-.15l7.41-5.47A.76.76,0,0,0,20,13.89Z" style="fill:#669df6"/><path d="M12,20.52a.72.72,0,0,1,0-1.43h0v-.22L5,13.73a.76.76,0,0,0-1,.16.74.74,0,0,0,.16,1l7.41,5.47a.73.73,0,0,0,.44.15v0Z" style="fill:#aecbfa"/><path d="M12,18.34a1.47,1.47,0,1,0,1.47,1.47A1.47,1.47,0,0,0,12,18.34Zm0,2.18a.72.72,0,1,1,.72-.71A.71.71,0,0,1,12,20.52Z" style="fill:#4285f4"/><path d="M6,6.11a.76.76,0,0,1-.75-.75V3.48a.76.76,0,1,1,1.51,0V5.36A.76.76,0,0,1,6,6.11Z" style="fill:#aecbfa"/><circle cx="5.98" cy="12" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="9.79" r="0.76" style="fill:#aecbfa"/><circle cx="5.98" cy="7.57" r="0.76" style="fill:#aecbfa"/><path d="M18,8.31a.76.76,0,0,1-.75-.76V5.67a.75.75,0,1,1,1.5,0V7.55A.75.75,0,0,1,18,8.31Z" style="fill:#4285f4"/><circle cx="18.02" cy="12.01" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="9.76" r="0.76" style="fill:#4285f4"/><circle cx="18.02" cy="3.48" r="0.76" style="fill:#4285f4"/><path d="M12,15a.76.76,0,0,1-.75-.75V12.34a.76.76,0,0,1,1.51,0v1.89A.76.76,0,0,1,12,15Z" style="fill:#669df6"/><circle cx="12" cy="16.45" r="0.76" style="fill:#669df6"/><circle cx="12" cy="10.14" r="0.76" style="fill:#669df6"/><circle cx="12" cy="7.92" r="0.76" style="fill:#669df6"/><path d="M15,10.54a.76.76,0,0,1-.75-.75V7.91a.76.76,0,1,1,1.51,0V9.79A.76.76,0,0,1,15,10.54Z" style="fill:#4285f4"/><circle cx="15.01" cy="5.69" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="14.19" r="0.76" style="fill:#4285f4"/><circle cx="15.01" cy="11.97" r="0.76" style="fill:#4285f4"/><circle cx="8.99" cy="14.19" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="7.92" r="0.76" style="fill:#aecbfa"/><circle cx="8.99" cy="5.69" r="0.76" style="fill:#aecbfa"/><path d="M9,12.73A.76.76,0,0,1,8.24,12V10.1a.75.75,0,1,1,1.5,0V12A.75.75,0,0,1,9,12.73Z" style="fill:#aecbfa"/></svg>

+ 15 - 0
api/core/model_runtime/model_providers/vertex_ai/_common.py

@@ -0,0 +1,15 @@
+from core.model_runtime.errors.invoke import InvokeError
+
+
+class _CommonVertexAi:
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        pass

+ 0 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/__init__.py


+ 38 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro-vision.yaml

@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-vision-001
+label:
+  en_US: Gemini 1.0 Pro Vision
+model_type: llm
+features:
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 16384
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

+ 38 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.0-pro.yaml

@@ -0,0 +1,38 @@
+model: gemini-1.0-pro-002
+label:
+  en_US: Gemini 1.0 Pro
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32760
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

+ 38 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml

@@ -0,0 +1,38 @@
+model: gemini-1.5-flash-preview-0514
+label:
+  en_US: Gemini 1.5 Flash
+model_type: llm
+features:
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 1048576
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

+ 39 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml

@@ -0,0 +1,39 @@
+model: gemini-1.5-pro-preview-0514
+label:
+  en_US: Gemini 1.5 Pro
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 1048576
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      en_US: Top k
+    type: int
+    help:
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_output_tokens
+    use_template: max_tokens
+    required: true
+    default: 8192
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD

+ 438 - 0
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -0,0 +1,438 @@
+import base64
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import google.api_core.exceptions as exceptions
+import vertexai.generative_models as glm
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.generative_models import HarmBlockThreshold, HarmCategory
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+
+class VertexAiLargeLanguageModel(LargeLanguageModel):
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # invoke model
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+    
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:md = gml.GenerativeModel(model)
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
+        return self._get_num_tokens_by_gpt2(prompt)
+    
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        """
+        Format a list of messages into a full prompt for the Google model
+
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+        """
+        messages = messages.copy()  # don't mutate the original list
+
+        text = "".join(
+            self._convert_one_message_to_text(message)
+            for message in messages
+        )
+
+        return text.rstrip()
+    
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+        """
+        Convert tool messages to glm tools
+
+        :param tools: tool messages
+        :return: glm tools
+        """
+        return glm.Tool(
+            function_declarations=[
+                glm.FunctionDeclaration(
+                    name=tool.name,
+                    parameters=glm.Schema(
+                        type=glm.Type.OBJECT,
+                        properties={
+                            key: {
+                                'type_': value.get('type', 'string').upper(),
+                                'description': value.get('description', ''),
+                                'enum': value.get('enum', [])
+                            } for key, value in tool.parameters.get('properties', {}).items()
+                        },
+                        required=tool.parameters.get('required', [])
+                    ),
+                ) for tool in tools
+            ]
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        
+        try:
+            ping_message = SystemPromptMessage(content="ping")
+            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+            
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, 
+                  stream: bool = True, user: Optional[str] = None
+        ) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: credentials kwargs
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        config_kwargs = model_parameters.copy()
+        config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
+
+        if stop:
+            config_kwargs["stop_sequences"] = stop
+
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+        project_id = credentials["vertex_project_id"]
+        location = credentials["vertex_location"]
+        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+        history = []
+        system_instruction = GEMINI_BLOCK_MODE_PROMPT
+        # hack for gemini-pro-vision, which currently does not support multi-turn chat
+        if model == "gemini-1.0-pro-vision-001":
+            last_msg = prompt_messages[-1]
+            content = self._format_message_to_glm_content(last_msg)
+            history.append(content)
+        else:
+            for msg in prompt_messages:
+                if isinstance(msg, SystemPromptMessage):
+                    system_instruction = msg.content
+                else:
+                    content = self._format_message_to_glm_content(msg)
+                    if history and history[-1].role == content.role:
+                        history[-1].parts.extend(content.parts)
+                    else:
+                        history.append(content)
+
+        safety_settings={
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        }
+
+        google_model = glm.GenerativeModel(
+            model_name=model,
+            system_instruction=system_instruction
+        )
+
+        response = google_model.generate_content(
+            contents=history,
+            generation_config=glm.GenerationConfig(
+                **config_kwargs
+            ),
+            stream=stream,
+            safety_settings=safety_settings,
+            tools=self._convert_tools_to_glm_tool(tools) if tools else None
+        )
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.candidates[0].content.parts[0].text
+        )
+
+        # calculate num tokens
+        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+        )
+
+        return result
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator result
+        """
+        index = -1
+        for chunk in response:
+            for part in chunk.candidates[0].content.parts:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+
+                if part.text:
+                    assistant_prompt_message.content += part.text
+
+                if part.function_call:
+                    assistant_prompt_message.tool_calls = [
+                        AssistantPromptMessage.ToolCall(
+                            id=part.function_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=part.function_call.name,
+                                arguments=json.dumps({
+                                    key: value 
+                                    for key, value in part.function_call.args.items()
+                                })
+                            )
+                        )
+                    ]
+
+                index += 1
+    
+                if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason:                    
+                    # transform assistant message to prompt message
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message
+                        )
+                    )
+                else:
+                    
+                    # calculate num tokens
+                    prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                    completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                    # transform usage
+                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                    
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                            finish_reason=chunk.candidates[0].finish_reason,
+                            usage=usage
+                        )
+                    )
+
+    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+        """
+        Convert a single message to a string.
+
+        :param message: PromptMessage to convert.
+        :return: String representation of the message.
+        """
+        human_prompt = "\n\nuser:"
+        ai_prompt = "\n\nmodel:"
+
+        content = message.content
+        if isinstance(content, list):
+            content = "".join(
+                c.data for c in content if c.type != PromptMessageContentType.IMAGE
+            )
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+        """
+        Format a single message into glm.Content for Google API
+
+        :param message: one PromptMessage
+        :return: glm Content representation of message
+        """
+        if isinstance(message, UserPromptMessage):
+            glm_content = glm.Content(role="user", parts=[])
+
+            if (isinstance(message.content, str)):
+                glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)])
+            else:
+                parts = []
+                for c in message.content:
+                    if c.type == PromptMessageContentType.TEXT:
+                        parts.append(glm.Part.from_text(c.data))
+                    else:
+                        metadata, data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                        blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                        parts.append(blob)
+                
+                glm_content = glm.Content(role="user", parts=[parts])
+            return glm_content
+        elif isinstance(message, AssistantPromptMessage):
+            if message.content:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)])
+            if message.tool_calls:
+                glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall(
+                    name=message.tool_calls[0].function.name,
+                    args=json.loads(message.tool_calls[0].function.arguments),
+                ))])
+            return glm_content
+        elif isinstance(message, ToolPromptMessage):
+            glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse(
+                    name=message.name,
+                    response={
+                        "response": message.content
+                    }
+                ))])
+            return glm_content
+        else:
+            raise ValueError(f"Got unknown type {message}")
+    
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the ermd = gml.GenerativeModel(model)ror type thrown to the caller
+        The value is the md = gml.GenerativeModel(model)error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke emd = gml.GenerativeModel(model)rror mapping
+        """
+        return {
+            InvokeConnectionError: [
+                exceptions.RetryError
+            ],
+            InvokeServerUnavailableError: [
+                exceptions.ServiceUnavailable,
+                exceptions.InternalServerError,
+                exceptions.BadGateway,
+                exceptions.GatewayTimeout,
+                exceptions.DeadlineExceeded
+            ],
+            InvokeRateLimitError: [
+                exceptions.ResourceExhausted,
+                exceptions.TooManyRequests
+            ],
+            InvokeAuthorizationError: [
+                exceptions.Unauthenticated,
+                exceptions.PermissionDenied,
+                exceptions.Unauthenticated,
+                exceptions.Forbidden
+            ],
+            InvokeBadRequestError: [
+                exceptions.BadRequest,
+                exceptions.InvalidArgument,
+                exceptions.FailedPrecondition,
+                exceptions.OutOfRange,
+                exceptions.NotFound,
+                exceptions.MethodNotAllowed,
+                exceptions.Conflict,
+                exceptions.AlreadyExists,
+                exceptions.Aborted,
+                exceptions.LengthRequired,
+                exceptions.PreconditionFailed,
+                exceptions.RequestRangeNotSatisfiable,
+                exceptions.Cancelled,
+            ]
+        }

+ 0 - 0
api/core/model_runtime/model_providers/vertex_ai/text_embedding/__init__.py


+ 8 - 0
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-embedding-004.yaml

@@ -0,0 +1,8 @@
+model: text-embedding-004
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD

+ 8 - 0
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text-multilingual-embedding-002.yaml

@@ -0,0 +1,8 @@
+model: text-multilingual-embedding-002
+model_type: text-embedding
+model_properties:
+  context_size: 2048
+pricing:
+  input: '0.00013'
+  unit: '0.001'
+  currency: USD

+ 193 - 0
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -0,0 +1,193 @@
+import base64
+import json
+import time
+from decimal import Decimal
+from typing import Optional
+
+import tiktoken
+from google.cloud import aiplatform
+from google.oauth2 import service_account
+from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    PriceConfig,
+    PriceType,
+)
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
+
+
+class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
+    """
+    Model class for Vertex AI text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: embeddings result
+        """
+        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+        project_id = credentials["vertex_project_id"]
+        location = credentials["vertex_location"]
+        aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+        client = VertexTextEmbeddingModel.from_pretrained(model)
+
+        
+
+        embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+            client=client,
+            texts=texts
+        )
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=embedding_used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=embeddings_batch,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        if len(texts) == 0:
+            return 0
+
+        try:
+            enc = tiktoken.encoding_for_model(model)
+        except KeyError:
+            enc = tiktoken.get_encoding("cl100k_base")
+
+        total_num_tokens = 0
+        for text in texts:
+            # calculate the number of tokens in the encoded text
+            tokenized_text = enc.encode(text)
+            total_num_tokens += len(tokenized_text)
+
+        return total_num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
+            project_id = credentials["vertex_project_id"]
+            location = credentials["vertex_location"]
+            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
+
+            client = VertexTextEmbeddingModel.from_pretrained(model)
+
+            # call embedding model
+            self._embedding_invoke(
+                model=model,
+                client=client,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore
+        """
+        Invoke embedding model
+
+        :param model: model name
+        :param client: model client
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+        response = client.get_embeddings(texts)
+
+        embeddings = []
+        token_usage = 0
+
+        for i in range(len(response)):
+            embeddings.append(response[i].values)
+            token_usage += int(response[i].statistics.token_count)
+
+        return embeddings, token_usage
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+    
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MAX_CHUNKS: 1,
+            },
+            parameter_rules=[],
+            pricing=PriceConfig(
+                input=Decimal(credentials.get('input_price', 0)),
+                unit=Decimal(credentials.get('unit', 0)),
+                currency=credentials.get('currency', "USD")
+            )
+        )
+
+        return entity

+ 31 - 0
api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py

@@ -0,0 +1,31 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            # Use `gemini-1.0-pro-002` model for validate,
+            model_instance.validate_credentials(
+                model='gemini-1.0-pro-002',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex

+ 43 - 0
api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml

@@ -0,0 +1,43 @@
+provider: vertex_ai
+label:
+  en_US: Vertex AI | Google Cloud Platform
+description:
+  en_US: Vertex AI in Google Cloud Platform.
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#FCFDFF"
+help:
+  title:
+    en_US: Get your Access Details from Google
+  url:
+    en_US: https://cloud.google.com/vertex-ai/
+supported_model_types:
+  - llm
+  - text-embedding
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: vertex_project_id
+      label:
+        en_US: Project ID
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter your Google Cloud Project ID
+    - variable: vertex_location
+      label:
+        en_US: Location
+      type: text-input
+      required: true
+      placeholder:
+        en_US: Enter your Google Cloud Location
+    - variable: vertex_service_account_key
+      label:
+        en_US: Service Account Key
+      type: secret-input
+      required: true
+      placeholder:
+        en_US: Enter your Google Cloud Service Account Key in base64 format

+ 1 - 0
api/requirements.txt

@@ -84,3 +84,4 @@ pgvecto-rs==0.1.4
 firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5
+google-cloud-aiplatform==1.49.0