
Add OCI (Oracle Cloud Infrastructure) Generative AI Service as a Model Provider (#7775)

Co-authored-by: Walter Jin <jinshuhaicc@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: walter from vm <walter.jin@oracle.com>
tmuife committed 7 months ago
Commit 89aede80cc
23 changed files with 1678 additions and 430 deletions
  1. +0 -0    api/core/model_runtime/model_providers/oci/__init__.py
  2. +1 -0    api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg
  3. +1 -0    api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg
  4. +52 -0   api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml
  5. +52 -0   api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml
  6. +461 -0  api/core/model_runtime/model_providers/oci/llm/llm.py
  7. +51 -0   api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml
  8. +34 -0   api/core/model_runtime/model_providers/oci/oci.py
  9. +42 -0   api/core/model_runtime/model_providers/oci/oci.yaml
  10. +0 -0   api/core/model_runtime/model_providers/oci/text_embedding/__init__.py
  11. +5 -0   api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml
  12. +9 -0   api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml
  13. +9 -0   api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml
  14. +9 -0   api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml
  15. +9 -0   api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml
  16. +9 -0   api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml
  17. +242 -0  api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
  18. +483 -430  api/poetry.lock
  19. +1 -0   api/pyproject.toml
  20. +0 -0   api/tests/integration_tests/model_runtime/oci/__init__.py
  21. +130 -0  api/tests/integration_tests/model_runtime/oci/test_llm.py
  22. +20 -0   api/tests/integration_tests/model_runtime/oci/test_provider.py
  23. +58 -0   api/tests/integration_tests/model_runtime/oci/test_text_embedding.py

+ 0 - 0
api/core/model_runtime/model_providers/oci/__init__.py


+ 1 - 0
api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg

@@ -0,0 +1 @@
+<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>

+ 1 - 0
api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg

@@ -0,0 +1 @@
+<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>

+ 52 - 0
api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml

@@ -0,0 +1,52 @@
+model: cohere.command-r-16k
+label:
+  en_US: cohere.command-r-16k v1.2
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 1
+    max: 1.0
+  - name: topP
+    use_template: top_p
+    default: 0.75
+    min: 0
+    max: 1
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presencePenalty
+    use_template: presence_penalty
+    min: 0
+    max: 1
+    default: 0
+  - name: frequencyPenalty
+    use_template: frequency_penalty
+    min: 0
+    max: 1
+    default: 0
+  - name: maxTokens
+    use_template: max_tokens
+    default: 600
+    max: 4000
+pricing:
+  input: '0.004'
+  output: '0.004'
+  unit: '0.0001'
+  currency: USD

+ 52 - 0
api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml

@@ -0,0 +1,52 @@
+model: cohere.command-r-plus
+label:
+  en_US: cohere.command-r-plus v1.2
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 1
+    max: 1.0
+  - name: topP
+    use_template: top_p
+    default: 0.75
+    min: 0
+    max: 1
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presencePenalty
+    use_template: presence_penalty
+    min: 0
+    max: 1
+    default: 0
+  - name: frequencyPenalty
+    use_template: frequency_penalty
+    min: 0
+    max: 1
+    default: 0
+  - name: maxTokens
+    use_template: max_tokens
+    default: 600
+    max: 4000
+pricing:
+  input: '0.0219'
+  output: '0.0219'
+  unit: '0.0001'
+  currency: USD

+ 461 - 0
api/core/model_runtime/model_providers/oci/llm/llm.py

@@ -0,0 +1,461 @@
+import base64
+import copy
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, Union
+
+import oci
+from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+request_template = {
+    "compartmentId": "",
+    "servingMode": {
+        "modelId": "cohere.command-r-plus",
+        "servingType": "ON_DEMAND"
+    },
+    "chatRequest": {
+        "apiFormat": "COHERE",
+        #"preambleOverride": "You are a helpful assistant.",
+        #"message": "Hello!",
+        #"chatHistory": [],
+        "maxTokens": 600,
+        "isStream": False,
+        "frequencyPenalty": 0,
+        "presencePenalty": 0,
+        "temperature": 1,
+        "topP": 0.75
+    }
+}
+oci_config_template = {
+        "user": "",
+        "fingerprint": "",
+        "tenancy": "",
+        "region": "",
+        "compartment_id": "",
+        "key_content": ""
+    }
+
+class OCILargeLanguageModel(LargeLanguageModel):
+    # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
+    _supported_models = {
+        "meta.llama-3-70b-instruct": {
+            "system": True,
+            "multimodal": False,
+            "tool_call": False,
+            "stream_tool_call": False,
+        },
+        "cohere.command-r-16k": {
+            "system": True,
+            "multimodal": False,
+            "tool_call": True,
+            "stream_tool_call": False,
+        },
+        "cohere.command-r-plus": {
+            "system": True,
+            "multimodal": False,
+            "tool_call": True,
+            "stream_tool_call": False,
+        },
+    }
+
+    def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["stream_tool_call"] if stream else feature["tool_call"]
+
+    def _is_multimodal_supported(self, model_id: str) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["multimodal"]
+
+    def _is_system_prompt_supported(self, model_id: str) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["system"]
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+
+        # invoke model
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: number of tokens
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
+        return self._get_num_tokens_by_gpt2(prompt)
+
+    def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of characters for the given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: number of characters
+        """
+        prompt = self._convert_messages_to_prompt(prompt_messages)
+
+        return len(prompt)
+
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+        """
+        :param messages: List of PromptMessage to combine.
+        :return: Combined string with necessary human_prompt and ai_prompt tags.
+        """
+        messages = messages.copy()  # don't mutate the original list
+
+        text = "".join(
+            self._convert_one_message_to_text(message)
+            for message in messages
+        )
+
+        return text.rstrip()
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        # Setup basic variables
+        # Auth Config
+        try:
+            ping_message = SystemPromptMessage(content="ping")
+            self._generate(model, credentials, [ping_message], {"maxTokens": 5})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None
+                  ) -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: credentials kwargs
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+
+        # initialize client
+        # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
+        oci_config = copy.deepcopy(oci_config_template)
+        if "oci_config_content" in credentials:
+            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
+            config_items = oci_config_content.split("/")
+            if len(config_items) != 5:
+                raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
+            oci_config["user"] = config_items[0]
+            oci_config["fingerprint"] = config_items[1]
+            oci_config["tenancy"] = config_items[2]
+            oci_config["region"] = config_items[3]
+            oci_config["compartment_id"] = config_items[4]
+        else:
+            raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
+        if "oci_key_content" in credentials:
+            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
+            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
+        else:
+            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
+
+        #oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
+        compartment_id = oci_config["compartment_id"]
+        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
+        # call embedding model
+        request_args = copy.deepcopy(request_template)
+        request_args["compartmentId"] = compartment_id
+        request_args["servingMode"]["modelId"] = model
+
+        chathistory = []
+        system_prompts = []
+        request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
+        request_args["chatRequest"].update(model_parameters)
+        frequency_penalty = model_parameters.get("frequencyPenalty", 0)
+        presence_penalty = model_parameters.get("presencePenalty", 0)
+        if frequency_penalty > 0 and presence_penalty > 0:
+            raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")
+
+        # tool calling is not implemented yet; only check whether the model could support it
+        valid_value = self._is_tool_call_supported(model, stream)
+        if tools is not None and len(tools) > 0:
+            if not valid_value:
+                raise InvokeBadRequestError("Does not support function calling")
+        if model.startswith("cohere"):
+            for message in prompt_messages[:-1]:
+                text = ""
+                if isinstance(message.content, str):
+                    text = message.content
+                if isinstance(message, UserPromptMessage):
+                    chathistory.append({"role": "USER", "message": text})
+                else:
+                    chathistory.append({"role": "CHATBOT", "message": text})
+                if isinstance(message, SystemPromptMessage):
+                    if isinstance(message.content, str):
+                        system_prompts.append(message.content)
+            args = {"apiFormat": "COHERE",
+                    "preambleOverride": ' '.join(system_prompts),
+                    "message": prompt_messages[-1].content,
+                    "chatHistory": chathistory, }
+            request_args["chatRequest"].update(args)
+        elif model.startswith("meta"):
+            meta_messages = []
+            for message in prompt_messages:
+                text = message.content
+                meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
+            args = {"apiFormat": "GENERIC",
+                    "messages": meta_messages,
+                    "numGenerations": 1,
+                    "topK": -1}
+            request_args["chatRequest"].update(args)
+
+        if stream:
+            request_args["chatRequest"]["isStream"] = True
+        response = client.chat(request_args)
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
+                                  prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.data.chat_response.text
+        )
+
+        # calculate num tokens
+        prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
+        completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        result = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+        )
+
+        return result
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator result
+        """
+        index = -1
+        events = response.data.events()
+        for stream in events:
+            chunk = json.loads(stream.data)
+            # example chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}
+            if "finishReason" not in chunk:
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=''
+                )
+                if model.startswith("cohere"):
+                    if chunk["text"]:
+                        assistant_prompt_message.content += chunk["text"]
+                elif model.startswith("meta"):
+                    assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
+                index += 1
+                # transform assistant message to prompt message
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                # calculate num tokens
+                prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        finish_reason=str(chunk["finishReason"]),
+                        usage=usage
+                    )
+                )
+
+    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
+        """
+        Convert a single message to a string.
+
+        :param message: PromptMessage to convert.
+        :return: String representation of the message.
+        """
+        human_prompt = "\n\nuser:"
+        ai_prompt = "\n\nmodel:"
+
+        content = message.content
+        if isinstance(content, list):
+            content = "".join(
+                c.data for c in content if c.type != PromptMessageContentType.IMAGE
+            )
+
+        if isinstance(message, UserPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, AssistantPromptMessage):
+            message_text = f"{ai_prompt} {content}"
+        elif isinstance(message, SystemPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: []
+        }
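For reference, a rough sketch (not part of the commit) of the chatRequest payload that _generate assembles for a Cohere-family model before calling client.chat; the compartment OCID, preamble and messages below are placeholders.

request_args = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder OCID
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "preambleOverride": "You are a helpful assistant.",      # joined system prompts
        "message": "What can the OCI Generative AI service do?",  # last prompt message
        "chatHistory": [
            {"role": "USER", "message": "Hello!"},
            {"role": "CHATBOT", "message": "Hi! How can I help?"},
        ],
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
    },
}
# response = client.chat(request_args)
# response.data.chat_response.text holds the completion for non-streaming calls

For meta.llama models, the code instead sets apiFormat GENERIC and passes a messages list, as handled in the elif branch above.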

+ 51 - 0
api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml

@@ -0,0 +1,51 @@
+model: meta.llama-3-70b-instruct
+label:
+  zh_Hans: meta.llama-3-70b-instruct
+  en_US: meta.llama-3-70b-instruct
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 131072
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 1
+    max: 2.0
+  - name: topP
+    use_template: top_p
+    default: 0.75
+    min: 0
+    max: 1
+  - name: topK
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presencePenalty
+    use_template: presence_penalty
+    min: -2
+    max: 2
+    default: 0
+  - name: frequencyPenalty
+    use_template: frequency_penalty
+    min: -2
+    max: 2
+    default: 0
+  - name: maxTokens
+    use_template: max_tokens
+    default: 600
+    max: 8000
+pricing:
+  input: '0.015'
+  output: '0.015'
+  unit: '0.0001'
+  currency: USD

+ 34 - 0
api/core/model_runtime/model_providers/oci/oci.py

@@ -0,0 +1,34 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class OCIGENAIProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            # use the `cohere.command-r-plus` model for validation
+            model_instance.validate_credentials(
+                model='cohere.command-r-plus',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex
+
+

+ 42 - 0
api/core/model_runtime/model_providers/oci/oci.yaml

@@ -0,0 +1,42 @@
+provider: oci
+label:
+  en_US: OCIGenerativeAI
+description:
+  en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
+  zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+background: "#FFFFFF"
+help:
+  title:
+    en_US: Get your API Key from OCI
+    zh_Hans: 从 OCI 获取 API Key
+  url:
+    en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
+supported_model_types:
+  - llm
+  - text-embedding
+  #- rerank
+configurate_methods:
+  - predefined-model
+  #- customizable-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: oci_config_content
+      label:
+        en_US: oci api key config file's content
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
+        en_US: Enter your oci api key config file's content(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
+    - variable: oci_key_content
+      label:
+        en_US: oci api key file's content
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
+        en_US: Enter your oci api key file's content(base64.b64encode("pem file content".encode('utf-8')))
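As a quick illustration of the two fields described by these placeholders, here is a minimal sketch of how the values can be produced; the OCIDs, fingerprint, region and key path are hypothetical.

import base64

# user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid, slash-separated (placeholders)
config_plain = (
    "ocid1.user.oc1..example/aa:bb:cc:dd:ee/ocid1.tenancy.oc1..example/"
    "us-chicago-1/ocid1.compartment.oc1..example"
)
oci_config_content = base64.b64encode(config_plain.encode("utf-8")).decode("utf-8")

# contents of the API signing key PEM file, also base64-encoded
with open("oci_api_key.pem", "rb") as f:
    oci_key_content = base64.b64encode(f.read()).decode("utf-8")

print(oci_config_content)
print(oci_key_content)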

+ 0 - 0
api/core/model_runtime/model_providers/oci/text_embedding/__init__.py


+ 5 - 0
api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml

@@ -0,0 +1,5 @@
+- cohere.embed-english-light-v2.0
+- cohere.embed-english-light-v3.0
+- cohere.embed-english-v3.0
+- cohere.embed-multilingual-light-v3.0
+- cohere.embed-multilingual-v3.0

+ 9 - 0
api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml

@@ -0,0 +1,9 @@
+model: cohere.embed-english-light-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.001'
+  unit: '0.0001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml

@@ -0,0 +1,9 @@
+model: cohere.embed-english-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.001'
+  unit: '0.0001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml

@@ -0,0 +1,9 @@
+model: cohere.embed-english-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.001'
+  unit: '0.0001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml

@@ -0,0 +1,9 @@
+model: cohere.embed-multilingual-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.001'
+  unit: '0.0001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml

@@ -0,0 +1,9 @@
+model: cohere.embed-multilingual-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.001'
+  unit: '0.0001'
+  currency: USD

+ 242 - 0
api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py

@@ -0,0 +1,242 @@
+import base64
+import copy
+import time
+from typing import Optional
+
+import numpy as np
+import oci
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+request_template = {
+    "compartmentId": "",
+    "servingMode": {
+        "modelId": "cohere.embed-english-light-v3.0",
+        "servingType": "ON_DEMAND"
+    },
+    "truncate": "NONE",
+    "inputs": [""]
+}
+oci_config_template = {
+        "user": "",
+        "fingerprint": "",
+        "tenancy": "",
+        "region": "",
+        "compartment_id": "",
+        "key_content": ""
+    }
+class OCITextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for OCI text embedding models.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        inputs = []
+        indices = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+
+            # Here token count is only an approximation based on the GPT2 tokenizer
+            num_tokens = self._get_num_tokens_by_gpt2(text)
+
+            if num_tokens >= context_size:
+                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
+                # if num tokens is larger than context length, only use the start
+                inputs.append(text[0: cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+
+        for i in _iter:
+            # call embedding model
+            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=inputs[i: i + max_chunks]
+            )
+
+            used_tokens += embedding_used_tokens
+            batched_embeddings += embeddings_batch
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=batched_embeddings,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for the given texts
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+    def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of characters for the given texts
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        characters = 0
+        for text in texts:
+            characters += len(text)
+        return characters
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # call embedding model
+            self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+
+        # oci
+        # initialize client
+        oci_config = copy.deepcopy(oci_config_template)
+        if "oci_config_content" in credentials:
+            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
+            config_items = oci_config_content.split("/")
+            if len(config_items) != 5:
+                raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
+            oci_config["user"] = config_items[0]
+            oci_config["fingerprint"] = config_items[1]
+            oci_config["tenancy"] = config_items[2]
+            oci_config["region"] = config_items[3]
+            oci_config["compartment_id"] = config_items[4]
+        else:
+            raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
+        if "oci_key_content" in credentials:
+            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
+            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
+        else:
+            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
+        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
+        compartment_id = oci_config["compartment_id"]
+        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
+        # call embedding model
+        request_args = copy.deepcopy(request_template)
+        request_args["compartmentId"] = compartment_id
+        request_args["servingMode"]["modelId"] = model
+        request_args["inputs"] = texts
+        response = client.embed_text(request_args)
+        return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                KeyError
+            ]
+        }
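Likewise, a sketch (not part of the commit) of the request that _embedding_invoke builds for GenerativeAiInferenceClient.embed_text; the compartment OCID is a placeholder.

request_args = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder OCID
    "servingMode": {"modelId": "cohere.embed-multilingual-v3.0", "servingType": "ON_DEMAND"},
    "truncate": "NONE",
    "inputs": ["hello", "world"],
}
# response = client.embed_text(request_args)
# response.data.embeddings contains one vector per input string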

File diff suppressed because it is too large
+ 483 - 430
api/poetry.lock


+ 1 - 0
api/pyproject.toml

@@ -190,6 +190,7 @@ zhipuai = "1.0.7"
 azure-ai-ml = "^1.19.0"
 azure-ai-inference = "^1.0.0b3"
 volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
+oci = "^2.133.0"
 [tool.poetry.group.indriect.dependencies]
 kaleido = "0.2.1"
 rank-bm25 = "~0.2.2"
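With the oci dependency declared here and poetry.lock already regenerated in this commit, running poetry install from the api directory should be enough to pull in the OCI Python SDK.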

+ 0 - 0
api/tests/integration_tests/model_runtime/oci/__init__.py


+ 130 - 0
api/tests/integration_tests/model_runtime/oci/test_llm.py

@@ -0,0 +1,130 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel
+
+
+def test_validate_credentials():
+    model = OCILargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="cohere.command-r-plus",
+            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
+        )
+
+    model.validate_credentials(
+        model="cohere.command-r-plus",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+    )
+
+
+def test_invoke_model():
+    model = OCILargeLanguageModel()
+
+    response = model.invoke(
+        model="cohere.command-r-plus",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 10},
+        stop=["How"],
+        stream=False,
+        user="abc-123",
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = OCILargeLanguageModel()
+
+    response = model.invoke(
+        model="meta.llama-3-70b-instruct",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
+        stream=True,
+        user="abc-123",
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_invoke_model_with_function():
+    model = OCILargeLanguageModel()
+
+    response = model.invoke(
+        model="cohere.command-r-plus",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        prompt_messages=[UserPromptMessage(content="Hi")],
+        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
+        stream=False,
+        user="abc-123",
+        tools=[
+            PromptMessageTool(
+                name="get_current_weather",
+                description="Get the current weather in a given location",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            )
+        ],
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_get_num_tokens():
+    model = OCILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model="cohere.command-r-plus",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content="You are a helpful AI assistant.",
+            ),
+            UserPromptMessage(content="Hello World!"),
+        ],
+    )
+
+    assert num_tokens == 18
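Note that these integration tests read the credentials from the OCI_CONFIG_CONTENT and OCI_KEY_CONTENT environment variables, base64-encoded as described in oci.yaml; with those exported, they can be run with pytest against api/tests/integration_tests/model_runtime/oci.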

+ 20 - 0
api/tests/integration_tests/model_runtime/oci/test_provider.py

@@ -0,0 +1,20 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider
+
+
+def test_validate_provider_credentials():
+    provider = OCIGENAIProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(credentials={})
+
+    provider.validate_provider_credentials(
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        }
+    )

+ 58 - 0
api/tests/integration_tests/model_runtime/oci/test_text_embedding.py

@@ -0,0 +1,58 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel
+
+
+def test_validate_credentials():
+    model = OCITextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="cohere.embed-multilingual-v3.0",
+            credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
+        )
+
+    model.validate_credentials(
+        model="cohere.embed-multilingual-v3.0",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+    )
+
+
+def test_invoke_model():
+    model = OCITextEmbeddingModel()
+
+    result = model.invoke(
+        model="cohere.embed-multilingual-v3.0",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="abc-123",
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 4
+    # assert result.usage.total_tokens == 811
+
+
+def test_get_num_tokens():
+    model = OCITextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model="cohere.embed-multilingual-v3.0",
+        credentials={
+            "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
+            "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
+        },
+        texts=["hello", "world"],
+    )
+
+    assert num_tokens == 2
