
feat: Add Cohere Command R / R+ model support (#3333)

takatost committed 1 year ago
Parent commit: 826c422ac4

+ 2 - 0
api/core/model_runtime/model_providers/cohere/llm/_position.yaml

@@ -1,3 +1,5 @@
+- command-r
+- command-r-plus
 - command-chat
 - command-light-chat
 - command-nightly-chat

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-light.yaml

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml

@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'

+ 45 - 0
api/core/model_runtime/model_providers/cohere/llm/command-r-plus.yaml

@@ -0,0 +1,45 @@
+model: command-r-plus
+label:
+  en_US: command-r-plus
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '3'
+  output: '15'
+  unit: '0.000001'
+  currency: USD
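
The parameter_rules above mirror the Cohere Chat API's own parameter names (p, k, max_tokens). A minimal sketch of a direct SDK call using the same defaults; the API key and prompt are placeholders, not part of this change:

    import cohere

    client = cohere.Client("<api_key>")  # placeholder credential
    response = client.chat(
        message="Summarize the release notes in one sentence.",
        model="command-r-plus",
        temperature=0.3,   # the rule above caps this at 5.0
        p=0.75,            # the 'p' (top_p) rule's default
        k=0,               # the 'k' (Top k) rule's default
        max_tokens=1024,   # new default introduced in this commit
    )
    print(response.text)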

+ 45 - 0
api/core/model_runtime/model_providers/cohere/llm/command-r.yaml

@@ -0,0 +1,45 @@
+model: command-r
+label:
+  en_US: command-r
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '0.5'
+  output: '1.5'
+  unit: '0.000001'
+  currency: USD
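
The pricing block appears to be read as price per token times unit, so command-r works out to USD 0.50 per million input tokens and USD 1.50 per million output tokens. A quick arithmetic check using the fields above:

    input_price, output_price, unit = 0.5, 1.5, 0.000001  # from the YAML above

    # e.g. a request billed at 3,000 input and 500 output tokens:
    cost = 3_000 * input_price * unit + 500 * output_price * unit
    print(f"{cost:.6f} USD")  # 0.002250 USD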

+ 1 - 1
api/core/model_runtime/model_providers/cohere/llm/command.yaml

@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'

+ 251 - 90
api/core/model_runtime/model_providers/cohere/llm/llm.py

@@ -1,20 +1,38 @@
+import json
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union, cast
 
 import cohere
-from cohere.responses import Chat, Generations
-from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
-from cohere.responses.generation import StreamingGenerations, StreamingText
+from cohere import (
+    ChatMessage,
+    ChatStreamRequestToolResultsItem,
+    GenerateStreamedResponse,
+    GenerateStreamedResponse_StreamEnd,
+    GenerateStreamedResponse_StreamError,
+    GenerateStreamedResponse_TextGeneration,
+    Generation,
+    NonStreamedChatResponse,
+    StreamedChatResponse,
+    StreamedChatResponse_StreamEnd,
+    StreamedChatResponse_TextGeneration,
+    StreamedChatResponse_ToolCallsGeneration,
+    Tool,
+    ToolCall,
+    ToolParameterDefinitionsValue,
+)
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
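
The import changes track the move from the cohere 4.x SDK (cohere.responses.*) to 5.x, where response and stream-event types are exported directly from the cohere package and per-request behaviour is set via cohere.core.RequestOptions. A minimal sketch of the new call style, assembled from the names used in the hunks below; the prompt and key are placeholders:

    import cohere
    from cohere import StreamedChatResponse_TextGeneration
    from cohere.core import RequestOptions

    client = cohere.Client("<api_key>")  # placeholder credential
    for event in client.chat_stream(
        message="hello",
        model="command-r",
        request_options=RequestOptions(max_retries=0),
    ):
        if isinstance(event, StreamedChatResponse_TextGeneration):
            print(event.text, end="")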
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 credentials=credentials,
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
+                tools=tools,
                 stop=stop,
                 stream=stream,
                 user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         if stop:
             model_parameters['end_sequences'] = stop
 
-        response = client.generate(
-            prompt=prompt_messages[0].content,
-            model=model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
+            response = client.generate_stream(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
+
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.generate(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                   prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         )
 
         # calculate num tokens
-        prompt_tokens = response.meta['billed_units']['input_tokens']
-        completion_tokens = response.meta['billed_units']['output_tokens']
+        prompt_tokens = int(response.meta.billed_units.input_tokens)
+        completion_tokens = int(response.meta.billed_units.output_tokens)
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         index = 1
         full_assistant_content = ''
         for chunk in response:
-            if isinstance(chunk, StreamingText):
-                chunk = cast(StreamingText, chunk)
+            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
+                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif chunk is None:
+            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
+                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
+
                 # calculate num tokens
-                prompt_tokens = response.meta['billed_units']['input_tokens']
-                completion_tokens = response.meta['billed_units']['output_tokens']
+                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+                completion_tokens = self._num_tokens_from_messages(
+                    model,
+                    credentials,
+                    [AssistantPromptMessage(content=full_assistant_content)]
+                )
 
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=''),
-                        finish_reason=response.finish_reason,
+                        finish_reason=chunk.finish_reason,
                         usage=usage
                     )
                 )
                 break
+            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
+                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
+                raise InvokeBadRequestError(chunk.err)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
+        :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
         :param user: unique user id
@@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
-        if user:
-            model_parameters['user_name'] = user
+        if stop:
+            model_parameters['stop_sequences'] = stop
+
+        if tools:
+            model_parameters['tools'] = self._convert_tools(tools)
 
-        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+        message, chat_histories, tool_results \
+            = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
 
         # chat model
         real_model = model
         if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
             real_model = model.removesuffix('-chat')
 
-        response = client.chat(
-            message=message,
-            chat_history=chat_histories,
-            model=real_model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+            response = client.chat_stream(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.chat(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
+            return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
+                                       prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
         Handle llm chat response
@@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response
         """
         assistant_text = response.text
 
+        tool_calls = []
+        if response.tool_calls:
+            for cohere_tool_call in response.tool_calls:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=cohere_tool_call.name,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=cohere_tool_call.name,
+                        arguments=json.dumps(cohere_tool_call.parameters)
+                    )
+                )
+                tool_calls.append(tool_call)
+
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
+            content=assistant_text,
+            tool_calls=tool_calls
         )
 
         # calculate num tokens
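
Cohere reports a tool call as a name plus a parameters dict rather than the OpenAI-style id plus JSON-string arguments, which is why the hunk above reuses the tool name as the call id and json.dumps the parameters. An illustration of that mapping with invented sample values:

    import json
    from cohere import ToolCall

    cohere_call = ToolCall(name="get_weather", parameters={"city": "Berlin"})

    openai_style_call = {
        "id": cohere_call.name,  # Cohere provides no separate call id
        "type": "function",
        "function": {
            "name": cohere_call.name,
            "arguments": json.dumps(cohere_call.parameters),  # dict -> JSON string
        },
    }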
@@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-        if stop:
-            # enforce stop tokens
-            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
-            assistant_prompt_message = AssistantPromptMessage(
-                content=assistant_text
-            )
-
         # transform response
         response = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
-            system_fingerprint=response.preamble
+            usage=usage
         )
 
         return response
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
-                                              prompt_messages: list[PromptMessage],
-                                              stop: Optional[list[str]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Iterator[StreamedChatResponse],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm chat stream response
 
         :param model: model name
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response chunk generator
         """
 
-        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
-                           preamble: Optional[str] = None) -> LLMResultChunk:
+        def final_response(full_text: str,
+                           tool_calls: list[AssistantPromptMessage.ToolCall],
+                           index: int,
+                           finish_reason: Optional[str] = None) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
 
             full_assistant_prompt_message = AssistantPromptMessage(
-                content=full_text
+                content=full_text,
+                tool_calls=tool_calls
             )
             completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
 
@@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             return LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
-                system_fingerprint=preamble,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=''),
+                    message=AssistantPromptMessage(content='', tool_calls=tool_calls),
                     finish_reason=finish_reason,
                     usage=usage
                 )
@@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         index = 1
         full_assistant_content = ''
+        tool_calls = []
         for chunk in response:
-            if isinstance(chunk, StreamTextGeneration):
-                chunk = cast(StreamTextGeneration, chunk)
+            if isinstance(chunk, StreamedChatResponse_TextGeneration):
+                chunk = cast(StreamedChatResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     content=text
                 )
 
-                # stop
-                # notice: This logic can only cover few stop scenarios
-                if stop and text in stop:
-                    yield final_response(full_assistant_content, index, 'stop')
-                    break
-
                 full_assistant_content += text
 
                 yield LLMResultChunk(
@@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif isinstance(chunk, StreamEnd):
-                chunk = cast(StreamEnd, chunk)
-                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+            elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
+                chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
+
+                tool_calls = []
+                if chunk.tool_calls:
+                    for cohere_tool_call in chunk.tool_calls:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=cohere_tool_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=cohere_tool_call.name,
+                                arguments=json.dumps(cohere_tool_call.parameters)
+                            )
+                        )
+                        tool_calls.append(tool_call)
+            elif isinstance(chunk, StreamedChatResponse_StreamEnd):
+                chunk = cast(StreamedChatResponse_StreamEnd, chunk)
+                yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
                 index += 1
 
     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> tuple[str, list[dict]]:
+            -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
         :return:
         """
         chat_histories = []
+        latest_tool_call_n_outputs = []
         for prompt_message in prompt_messages:
-            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+            if prompt_message.role == PromptMessageRole.ASSISTANT:
+                prompt_message = cast(AssistantPromptMessage, prompt_message)
+                if prompt_message.tool_calls:
+                    for tool_call in prompt_message.tool_calls:
+                        latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
+                            call=ToolCall(
+                                name=tool_call.function.name,
+                                parameters=json.loads(tool_call.function.arguments)
+                            ),
+                            outputs=[]
+                        ))
+                else:
+                    cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                    if cohere_prompt_message:
+                        chat_histories.append(cohere_prompt_message)
+            elif prompt_message.role == PromptMessageRole.TOOL:
+                prompt_message = cast(ToolPromptMessage, prompt_message)
+                if latest_tool_call_n_outputs:
+                    i = 0
+                    for tool_call_n_outputs in latest_tool_call_n_outputs:
+                        if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
+                            latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
+                                call=ToolCall(
+                                    name=tool_call_n_outputs.call.name,
+                                    parameters=tool_call_n_outputs.call.parameters
+                                ),
+                                outputs=[{
+                                    "result": prompt_message.content
+                                }]
+                            )
+                            break
+                        i += 1
+            else:
+                cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                if cohere_prompt_message:
+                    chat_histories.append(cohere_prompt_message)
+
+        if latest_tool_call_n_outputs:
+            new_latest_tool_call_n_outputs = []
+            for tool_call_n_outputs in latest_tool_call_n_outputs:
+                if tool_call_n_outputs.outputs:
+                    new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
+
+            latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
 
         # get latest message from chat histories and pop it
         if len(chat_histories) > 0:
             latest_message = chat_histories.pop()
-            message = latest_message['message']
+            message = latest_message.message
         else:
             raise ValueError('Prompt messages is empty')
 
-        return message, chat_histories
+        return message, chat_histories, latest_tool_call_n_outputs
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
         """
         Convert PromptMessage to dict for Cohere model
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "USER", "message": message.content}
+                chat_message = ChatMessage(role="USER", message=message.content)
             else:
                 sub_message_text = ''
                 for message_content in message.content:
@@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_text += message_content.data
 
-                message_dict = {"role": "USER", "message": sub_message_text}
+                chat_message = ChatMessage(role="USER", message=sub_message_text)
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "CHATBOT", "message": message.content}
+            if not message.content:
+                return None
+            chat_message = ChatMessage(role="CHATBOT", message=message.content)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "USER", "message": message.content}
+            chat_message = ChatMessage(role="USER", message=message.content)
+        elif isinstance(message, ToolPromptMessage):
+            return None
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name:
-            message_dict["user_name"] = message.name
+        return chat_message
+
+    def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
+        """
+        Convert tools to Cohere model
+        """
+        cohere_tools = []
+        for tool in tools:
+            properties = tool.parameters['properties']
+            required_properties = tool.parameters['required']
+
+            parameter_definitions = {}
+            for p_key, p_val in properties.items():
+                required = False
+                if p_key in required_properties:
+                    required = True
+
+                desc = p_val['description']
+                if 'enum' in p_val:
+                    desc += (f"; Only accepts one of the following predefined options: "
+                             f"[{', '.join(p_val['enum'])}]")
+
+                parameter_definitions[p_key] = ToolParameterDefinitionsValue(
+                    description=desc,
+                    type=p_val['type'],
+                    required=required
+                )
 
-        return message_dict
+            cohere_tool = Tool(
+                name=tool.name,
+                description=tool.description,
+                parameter_definitions=parameter_definitions
+            )
+
+            cohere_tools.append(cohere_tool)
+
+        return cohere_tools
 
     def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
         """
@@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             model=model
         )
 
-        return response.length
+        return len(response.tokens)
 
     def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
-        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
-        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        calc_messages = []
+        for message in messages:
+            cohere_message = self._convert_prompt_message_to_dict(message)
+            if cohere_message:
+                calc_messages.append(cohere_message)
+        message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
         message_str = "\n".join(message_strs)
 
         real_model = model
@@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
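
A hedged sketch of the idea behind _invoke_error_mapping: when an SDK call raises, the shared runtime code walks this dict and re-raises the first matching wrapper class. The helper below is illustrative only; the real logic lives in the model-runtime base class, not in this file:

    import cohere
    from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeRateLimitError

    error_mapping = {
        InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError],
        InvokeBadRequestError: [
            cohere.core.api_error.ApiError,
            cohere.errors.bad_request_error.BadRequestError,
        ],
    }

    def to_invoke_error(exc: Exception) -> Exception:
        # Return the runtime-level error wrapping the raised SDK exception, if any.
        for invoke_error, sdk_errors in error_mapping.items():
            if isinstance(exc, tuple(sdk_errors)):
                return invoke_error(str(exc))
        return exc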

+ 21 - 10
api/core/model_runtime/model_providers/cohere/rerank/rerank.py

@@ -1,6 +1,7 @@
 from typing import Optional
 
 import cohere
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
@@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):
 
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
-        results = client.rerank(
+        response = client.rerank(
             query=query,
             documents=docs,
             model=model,
-            top_n=top_n
+            top_n=top_n,
+            return_documents=True,
+            request_options=RequestOptions(max_retries=0)
         )
 
         rerank_documents = []
-        for idx, result in enumerate(results):
+        for idx, result in enumerate(response.results):
             # format document
             rerank_document = RerankDocument(
                 index=result.index,
-                text=result.document['text'],
+                text=result.document.text,
                 score=result.relevance_score,
             )
 
@@ -108,13 +111,21 @@ class CohereRerankModel(RerankModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError,
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
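
With cohere 5.x the rerank endpoint returns a typed response object, hence result.document.text instead of a dict lookup, and return_documents=True is required for the documents to be echoed back at all. A minimal standalone sketch of the same call; the model name, query and documents are placeholders:

    import cohere
    from cohere.core import RequestOptions

    client = cohere.Client("<api_key>")  # placeholder credential
    response = client.rerank(
        query="reranking models",
        documents=["a passage about rerankers", "a passage about embeddings"],
        model="rerank-english-v2.0",     # placeholder model name
        top_n=2,
        return_documents=True,
        request_options=RequestOptions(max_retries=0),
    )
    for result in response.results:
        print(result.index, result.relevance_score, result.document.text)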

+ 27 - 16
api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py

@@ -3,7 +3,7 @@ from typing import Optional
 
 import cohere
 import numpy as np
-from cohere.responses import Tokens
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 text=text
             )
 
-            for j in range(0, tokenize_response.length, context_size):
-                tokens += [tokenize_response.token_strings[j: j + context_size]]
+            for j in range(0, len(tokenize_response), context_size):
+                tokens += [tokenize_response[j: j + context_size]]
                 indices += [i]
 
         batched_embeddings = []
@@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-        return response.length
+        return len(response)
 
-    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+    def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
         """
         Tokenize text
         :param model: model name
@@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         if not text:
-            return Tokens([], [], {})
+            return []
 
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
         response = client.tokenize(
             text=text,
-            model=model
+            model=model,
+            offline=False,
+            request_options=RequestOptions(max_retries=0)
         )
 
-        return response
+        return response.token_strings
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         response = client.embed(
             texts=texts,
             model=model,
-            input_type='search_document' if len(texts) > 1 else 'search_query'
+            input_type='search_document' if len(texts) > 1 else 'search_query',
+            request_options=RequestOptions(max_retries=1)
         )
 
-        return response.embeddings, response.meta['billed_units']['input_tokens']
+        return response.embeddings, int(response.meta.billed_units.input_tokens)
 
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
@@ -231,13 +234,21 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
+            ],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
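
The embedding path follows the same 5.x conventions: meta.billed_units is now an object rather than a dict, and the local _tokenize helper returns the response's token_strings list. A minimal sketch of the underlying embed call; the texts, key and model name are placeholders:

    import cohere
    from cohere.core import RequestOptions

    client = cohere.Client("<api_key>")  # placeholder credential
    response = client.embed(
        texts=["first chunk", "second chunk"],
        model="embed-english-v3.0",      # placeholder model name
        input_type="search_document",    # 'search_query' when embedding a single query
        request_options=RequestOptions(max_retries=1),
    )
    embeddings = response.embeddings
    used_tokens = int(response.meta.billed_units.input_tokens)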

+ 2 - 2
api/core/prompt/simple_prompt_transform.py

@@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
                     )
                 ),
                 max_token_limit=rest_tokens,
-                ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-                human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+                human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+                ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
             )
 
             # get prompt
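
Note: this hunk is an unrelated fix; the human_prefix and ai_prefix keyword arguments were previously swapped, so the human prefix rule was being passed as the AI prefix and vice versa. The change assigns each prompt-rule prefix to its matching parameter.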

+ 3 - 2
api/requirements.txt

@@ -47,7 +47,8 @@ replicate~=0.22.0
 websocket-client~=1.7.0
 dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
-transformers~=4.31.0
+transformers~=4.35.0
+tokenizers~=0.15.0
 pandas==1.5.3
 xinference-client==0.9.4
 safetensors==0.3.2
@@ -55,7 +56,7 @@ zhipuai==1.0.7
 werkzeug~=3.0.1
 pymilvus==2.3.0
 qdrant-client==1.7.3
-cohere~=4.44
+cohere~=5.2.4
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27
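
Note: the cohere bump from ~4.44 to ~5.2.4 is the change that drives the new SDK call style in llm.py, rerank.py and text_embedding.py above.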