Ver Fonte

feat: add cohere llm and embedding (#2115)

takatost há 1 ano atrás
pai
commit
a18dde9b0d
27 ficheiros alterados com 1689 adições e 3 exclusões
  1. 5 0
      api/core/model_runtime/model_providers/__base/large_language_model.py
  2. 43 2
      api/core/model_runtime/model_providers/cohere/cohere.yaml
  3. 0 0
      api/core/model_runtime/model_providers/cohere/llm/__init__.py
  4. 8 0
      api/core/model_runtime/model_providers/cohere/llm/_position.yaml
  5. 62 0
      api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml
  6. 62 0
      api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml
  7. 62 0
      api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml
  8. 44 0
      api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml
  9. 44 0
      api/core/model_runtime/model_providers/cohere/llm/command-light.yaml
  10. 62 0
      api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml
  11. 44 0
      api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml
  12. 44 0
      api/core/model_runtime/model_providers/cohere/llm/command.yaml
  13. 565 0
      api/core/model_runtime/model_providers/cohere/llm/llm.py
  14. 0 0
      api/core/model_runtime/model_providers/cohere/text_embedding/__init__.py
  15. 7 0
      api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml
  16. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml
  17. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml
  18. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml
  19. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml
  20. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml
  21. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml
  22. 9 0
      api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml
  23. 234 0
      api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
  24. 3 0
      api/core/spiltter/fixed_text_splitter.py
  25. 1 1
      api/requirements.txt
  26. 272 0
      api/tests/integration_tests/model_runtime/cohere/test_llm.py
  27. 64 0
      api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py

+ 5 - 0
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -1,5 +1,6 @@
 import logging
 import os
+import re
 import time
 from abc import abstractmethod
 from typing import Generator, List, Optional, Union
@@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel):
         """
         raise NotImplementedError
 
+    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
+        """Cut off the text as soon as any stop words occur."""
+        return re.split("|".join(stop), text, maxsplit=1)[0]
+
     def _llm_result_to_stream(self, result: LLMResult) -> Generator:
         """
         Transform llm result to stream

+ 43 - 2
api/core/model_runtime/model_providers/cohere/cohere.yaml

@@ -14,9 +14,12 @@ help:
   url:
     en_US: https://dashboard.cohere.com/api-keys
 supported_model_types:
+  - llm
+  - text-embedding
   - rerank
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -26,6 +29,44 @@ provider_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 请填写 API Key
-        en_US: Please fill in API Key
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
       show_on: [ ]
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key

+ 0 - 0
api/core/model_runtime/model_providers/cohere/llm/__init__.py


+ 8 - 0
api/core/model_runtime/model_providers/cohere/llm/_position.yaml

@@ -0,0 +1,8 @@
+- command-chat
+- command-light-chat
+- command-nightly-chat
+- command-light-nightly-chat
+- command
+- command-light
+- command-nightly
+- command-light-nightly

+ 62 - 0
api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml

@@ -0,0 +1,62 @@
+model: command-chat
+label:
+  zh_Hans: command-chat
+  en_US: command-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD

+ 62 - 0
api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml

@@ -0,0 +1,62 @@
+model: command-light-chat
+label:
+  zh_Hans: command-light-chat
+  en_US: command-light-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD

+ 62 - 0
api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml

@@ -0,0 +1,62 @@
+model: command-light-nightly-chat
+label:
+  zh_Hans: command-light-nightly-chat
+  en_US: command-light-nightly-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD

+ 44 - 0
api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml

@@ -0,0 +1,44 @@
+model: command-light-nightly
+label:
+  zh_Hans: command-light-nightly
+  en_US: command-light-nightly
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD

+ 44 - 0
api/core/model_runtime/model_providers/cohere/llm/command-light.yaml

@@ -0,0 +1,44 @@
+model: command-light
+label:
+  zh_Hans: command-light
+  en_US: command-light
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '0.3'
+  output: '0.6'
+  unit: '0.000001'
+  currency: USD

+ 62 - 0
api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml

@@ -0,0 +1,62 @@
+model: command-nightly-chat
+label:
+  zh_Hans: command-nightly-chat
+  en_US: command-nightly-chat
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+  - name: preamble_override
+    label:
+      zh_Hans: 前导文本
+      en_US: Preamble
+    type: string
+    help:
+      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
+      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
+    required: false
+  - name: prompt_truncation
+    label:
+      zh_Hans: 提示截断
+      en_US: Prompt Truncation
+    type: string
+    help:
+      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
+      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
+    required: true
+    default: 'AUTO'
+    options:
+      - 'AUTO'
+      - 'OFF'
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD

+ 44 - 0
api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml

@@ -0,0 +1,44 @@
+model: command-nightly
+label:
+  zh_Hans: command-nightly
+  en_US: command-nightly
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD

+ 44 - 0
api/core/model_runtime/model_providers/cohere/llm/command.yaml

@@ -0,0 +1,44 @@
+model: command
+label:
+  zh_Hans: command
+  en_US: command
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: completion
+  context_size: 4096
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 256
+    max: 4096
+pricing:
+  input: '1.0'
+  output: '2.0'
+  unit: '0.000001'
+  currency: USD

+ 565 - 0
api/core/model_runtime/model_providers/cohere/llm/llm.py

@@ -0,0 +1,565 @@
+import logging
+from typing import Generator, List, Optional, Union, cast, Tuple
+
+import cohere
+from cohere.responses import Chat, Generations
+from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd
+from cohere.responses.generation import StreamingText, StreamingGenerations
+
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
+                                                          PromptMessageContentType, SystemPromptMessage,
+                                                          TextPromptMessageContent, UserPromptMessage,
+                                                          PromptMessageTool)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
+from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \
+    InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class CohereLargeLanguageModel(LargeLanguageModel):
+    """
+    Model class for Cohere large language model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model, credentials)
+
+        if model_mode == LLMMode.CHAT:
+            return self._chat_generate(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+        else:
+            return self._generate(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model)
+
+        try:
+            if model_mode == LLMMode.CHAT:
+                return self._num_tokens_from_messages(model, credentials, prompt_messages)
+            else:
+                return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # get model mode
+            model_mode = self.get_model_mode(model)
+
+            if model_mode == LLMMode.CHAT:
+                self._chat_generate(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=[UserPromptMessage(content='ping')],
+                    model_parameters={
+                        'max_tokens': 20,
+                        'temperature': 0,
+                    },
+                    stream=False
+                )
+            else:
+                self._generate(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=[UserPromptMessage(content='ping')],
+                    model_parameters={
+                        'max_tokens': 20,
+                        'temperature': 0,
+                    },
+                    stream=False
+                )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        if stop:
+            model_parameters['end_sequences'] = stop
+
+        response = client.generate(
+            prompt=prompt_messages[0].content,
+            model=model,
+            stream=stream,
+            **model_parameters,
+        )
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+                                  prompt_messages: list[PromptMessage]) \
+            -> LLMResult:
+        """
+        Handle llm response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response
+        """
+        assistant_text = response.generations[0].text
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=assistant_text
+        )
+
+        # calculate num tokens
+        prompt_tokens = response.meta['billed_units']['input_tokens']
+        completion_tokens = response.meta['billed_units']['output_tokens']
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        index = 1
+        full_assistant_content = ''
+        for chunk in response:
+            if isinstance(chunk, StreamingText):
+                chunk = cast(StreamingText, chunk)
+                text = chunk.text
+
+                if text is None:
+                    continue
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=text
+                )
+
+                full_assistant_content += text
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+                index += 1
+            elif chunk is None:
+                # calculate num tokens
+                prompt_tokens = response.meta['billed_units']['input_tokens']
+                completion_tokens = response.meta['billed_units']['output_tokens']
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(content=''),
+                        finish_reason=response.finish_reason,
+                        usage=usage
+                    )
+                )
+                break
+
+    def _chat_generate(self, model: str, credentials: dict,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        """
+        Invoke llm chat model
+
+        :param model: model name
+        :param credentials: credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        if user:
+            model_parameters['user_name'] = user
+
+        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        # chat model
+        real_model = model
+        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
+            real_model = model.removesuffix('-chat')
+
+        response = client.chat(
+            message=message,
+            chat_history=chat_histories,
+            model=real_model,
+            stream=stream,
+            return_preamble=True,
+            **model_parameters,
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+
+        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
+                                       prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
+            -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response
+        """
+        assistant_text = response.text
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=assistant_text
+        )
+
+        # calculate num tokens
+        prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+        completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        if stop:
+            # enforce stop tokens
+            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
+            assistant_prompt_message = AssistantPromptMessage(
+                content=assistant_text
+            )
+
+        # transform response
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage,
+            system_fingerprint=response.preamble
+        )
+
+        return response
+
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
+                                              prompt_messages: list[PromptMessage],
+                                              stop: Optional[List[str]] = None) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+
+        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
+                           preamble: Optional[str] = None) -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+
+            full_assistant_prompt_message = AssistantPromptMessage(
+                content=full_text
+            )
+            completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                system_fingerprint=preamble,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=AssistantPromptMessage(content=''),
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        index = 1
+        full_assistant_content = ''
+        for chunk in response:
+            if isinstance(chunk, StreamTextGeneration):
+                chunk = cast(StreamTextGeneration, chunk)
+                text = chunk.text
+
+                if text is None:
+                    continue
+
+                # transform assistant message to prompt message
+                assistant_prompt_message = AssistantPromptMessage(
+                    content=text
+                )
+
+                # stop
+                # notice: This logic can only cover few stop scenarios
+                if stop and text in stop:
+                    yield final_response(full_assistant_content, index, 'stop')
+                    break
+
+                full_assistant_content += text
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+                index += 1
+            elif isinstance(chunk, StreamEnd):
+                chunk = cast(StreamEnd, chunk)
+                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+                index += 1
+
+    def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
+            -> Tuple[str, list[dict]]:
+        """
+        Convert prompt messages to message and chat histories
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        chat_histories = []
+        for prompt_message in prompt_messages:
+            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+
+        # get latest message from chat histories and pop it
+        if len(chat_histories) > 0:
+            latest_message = chat_histories.pop()
+            message = latest_message['message']
+        else:
+            raise ValueError('Prompt messages is empty')
+
+        return message, chat_histories
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for Cohere model
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "USER", "message": message.content}
+            else:
+                sub_message_text = ''
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_text += message_content.data
+
+                message_dict = {"role": "USER", "message": sub_message_text}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "CHATBOT", "message": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "USER", "message": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name is not None:
+            message_dict["user_name"] = message.name
+
+        return message_dict
+
+    def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
+        """
+        Calculate num tokens for text completion model.
+
+        :param model: model name
+        :param credentials: credentials
+        :param text: prompt text
+        :return: number of tokens
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        response = client.tokenize(
+            text=text,
+            model=model
+        )
+
+        return response.length
+
+    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
+        """Calculate num tokens Cohere model."""
+        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
+        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        message_str = "\n".join(message_strs)
+
+        real_model = model
+        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
+            real_model = model.removesuffix('-chat')
+
+        return self._num_tokens_from_string(real_model, credentials, message_str)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            Cohere supports fine-tuning of their models. This method returns the schema of the base model
+            but renamed to the fine-tuned model name.
+
+            :param model: model name
+            :param credentials: credentials
+
+            :return: model schema
+        """
+        # get model schema
+        models = self.predefined_models()
+        model_map = {model.model: model for model in models}
+
+        mode = credentials.get('mode')
+
+        if mode == 'chat':
+            base_model_schema = model_map['command-light-chat']
+        else:
+            base_model_schema = model_map['command-light']
+
+        base_model_schema = cast(AIModelEntity, base_model_schema)
+
+        base_model_schema_features = base_model_schema.features or []
+        base_model_schema_model_properties = base_model_schema.model_properties or {}
+        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                zh_Hans=model,
+                en_US=model
+            ),
+            model_type=ModelType.LLM,
+            features=[feature for feature in base_model_schema_features],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                key: property for key, property in base_model_schema_model_properties.items()
+            },
+            parameter_rules=[rule for rule in base_model_schema_parameters_rules],
+            pricing=base_model_schema.pricing
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                cohere.CohereConnectionError
+            ],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [
+                cohere.CohereAPIError,
+                cohere.CohereError,
+            ]
+        }

+ 0 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/__init__.py


+ 7 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/_position.yaml

@@ -0,0 +1,7 @@
+- embed-multilingual-v3.0
+- embed-multilingual-light-v3.0
+- embed-english-v3.0
+- embed-english-light-v3.0
+- embed-multilingual-v2.0
+- embed-english-v2.0
+- embed-english-light-v2.0

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v2.0.yaml

@@ -0,0 +1,9 @@
+model: embed-english-light-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-light-v3.0.yaml

@@ -0,0 +1,9 @@
+model: embed-english-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v2.0.yaml

@@ -0,0 +1,9 @@
+model: embed-english-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 4096
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-english-v3.0.yaml

@@ -0,0 +1,9 @@
+model: embed-english-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-light-v3.0.yaml

@@ -0,0 +1,9 @@
+model: embed-multilingual-light-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 384
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v2.0.yaml

@@ -0,0 +1,9 @@
+model: embed-multilingual-v2.0
+model_type: text-embedding
+model_properties:
+  context_size: 768
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 9 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/embed-multilingual-v3.0.yaml

@@ -0,0 +1,9 @@
+model: embed-multilingual-v3.0
+model_type: text-embedding
+model_properties:
+  context_size: 1024
+  max_chunks: 48
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 234 - 0
api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py

@@ -0,0 +1,234 @@
+import time
+from typing import Optional, Tuple
+
+import cohere
+import numpy as np
+from cohere.responses import Tokens
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
+    InvokeAuthorizationError, InvokeBadRequestError, InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class CohereTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Cohere text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        embeddings: list[list[float]] = [[] for _ in range(len(texts))]
+        tokens = []
+        indices = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+            tokenize_response = self._tokenize(
+                model=model,
+                credentials=credentials,
+                text=text
+            )
+
+            for j in range(0, tokenize_response.length, context_size):
+                tokens += [tokenize_response.token_strings[j: j + context_size]]
+                indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(tokens), max_chunks)
+
+        for i in _iter:
+            # call embedding model
+            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=["".join(token) for token in tokens[i: i + max_chunks]]
+            )
+
+            used_tokens += embedding_used_tokens
+            batched_embeddings += embeddings_batch
+
+        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
+        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
+        for i in range(len(indices)):
+            results[indices[i]].append(batched_embeddings[i])
+            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+        for i in range(len(texts)):
+            _result = results[i]
+            if len(_result) == 0:
+                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                    model=model,
+                    credentials=credentials,
+                    texts=[""]
+                )
+
+                used_tokens += embedding_used_tokens
+                average = embeddings_batch[0]
+            else:
+                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model,
+            credentials=credentials,
+            tokens=used_tokens
+        )
+
+        return TextEmbeddingResult(
+            embeddings=embeddings,
+            usage=usage,
+            model=model
+        )
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        if len(texts) == 0:
+            return 0
+
+        full_text = ' '.join(texts)
+
+        try:
+            response = self._tokenize(
+                model=model,
+                credentials=credentials,
+                text=full_text
+            )
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+        return response.length
+
+    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+        """
+        Tokenize text
+        :param model: model name
+        :param credentials: model credentials
+        :param text: text to tokenize
+        :return:
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        response = client.tokenize(
+            text=text,
+            model=model
+        )
+
+        return response
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # call embedding model
+            self._embedding_invoke(
+                model=model,
+                credentials=credentials,
+                texts=['ping']
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return: embeddings and used tokens
+        """
+        # initialize client
+        client = cohere.Client(credentials.get('api_key'))
+
+        # call embedding model
+        response = client.embed(
+            texts=texts,
+            model=model,
+            input_type='search_document' if len(texts) > 1 else 'search_query'
+        )
+
+        return response.embeddings, response.meta['billed_units']['input_tokens']
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                cohere.CohereConnectionError
+            ],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: [
+                cohere.CohereAPIError,
+                cohere.CohereError,
+            ]
+        }

+ 3 - 0
api/core/spiltter/fixed_text_splitter.py

@@ -24,6 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
             **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
+            if not text:
+                return 0
+
             if embedding_model_instance:
                 embedding_model_type_instance = embedding_model_instance.model_type_instance
                 embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

+ 1 - 1
api/requirements.txt

@@ -54,7 +54,7 @@ zhipuai==1.0.7
 werkzeug==2.3.8
 pymilvus==2.3.0
 qdrant-client==1.6.4
-cohere~=4.32
+cohere~=4.44
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27

+ 272 - 0
api/tests/integration_tests/model_runtime/cohere/test_llm.py

@@ -0,0 +1,272 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
+                                                          UserPromptMessage)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = CohereLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='command-light-chat',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_validate_credentials_for_completion_model():
+    model = CohereLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='command-light',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_invoke_completion_model():
+    model = CohereLargeLanguageModel()
+
+    credentials = {
+        'api_key': os.environ.get('COHERE_API_KEY')
+    }
+
+    result = model.invoke(
+        model='command-light',
+        credentials=credentials,
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 1
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
+
+
+def test_invoke_stream_completion_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_invoke_chat_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'p': 0.99,
+            'presence_penalty': 0.0,
+            'frequency_penalty': 0.0,
+            'max_tokens': 10
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert len(result.message.content) > 0
+
+    for chunk in model._llm_result_to_stream(result):
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_invoke_stream_chat_model():
+    model = CohereLargeLanguageModel()
+
+    result = model.invoke(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(result, Generator)
+
+    for chunk in result:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+        if chunk.delta.finish_reason is not None:
+            assert chunk.delta.usage is not None
+            assert chunk.delta.usage.completion_tokens > 0
+
+
+def test_get_num_tokens():
+    model = CohereLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='command-light',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 3
+
+    num_tokens = model.get_num_tokens(
+        model='command-light-chat',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 15
+
+
+def test_fine_tuned_model():
+    model = CohereLargeLanguageModel()
+
+    # test invoke
+    result = model.invoke(
+        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY'),
+            'mode': 'completion'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+
+
+def test_fine_tuned_chat_model():
+    model = CohereLargeLanguageModel()
+
+    # test invoke
+    result = model.invoke(
+        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY'),
+            'mode': 'chat'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.0,
+            'max_tokens': 100
+        },
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)

+ 64 - 0
api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py

@@ -0,0 +1,64 @@
+import os
+
+import pytest
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
+
+
+def test_validate_credentials():
+    model = CohereTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='embed-multilingual-v3.0',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        }
+    )
+
+
+def test_invoke_model():
+    model = CohereTextEmbeddingModel()
+
+    result = model.invoke(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        texts=[
+            "hello",
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 4
+    assert result.usage.total_tokens == 811
+
+
+def test_get_num_tokens():
+    model = CohereTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='embed-multilingual-v3.0',
+        credentials={
+            'api_key': os.environ.get('COHERE_API_KEY')
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 3