Преглед на файлове

Add bedrock command r models (#4521)

Co-authored-by: Justin Wu <justin.wu@ringcentral.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
Justin Wu преди 10 месеца
родител
ревизия
61f4f08744

+ 4 - 0
api/core/agent/output_parser/cot_output_parser.py

@@ -17,6 +17,10 @@ class CotAgentOutputParser:
                 action_name = None
                 action_input = None
 
+                # cohere always returns a list
+                if isinstance(action, list) and len(action) == 1:
+                    action = action[0]
+
                 for key, value in action.items():
                     if 'input' in key.lower():
                         action_input = value

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- cohere.command-r-plus-v1.0
+- cohere.command-r-v1.0
 - meta.llama3-8b-instruct-v1:0
 - meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1

+ 45 - 0
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml

@@ -0,0 +1,45 @@
+model: cohere.command-r-plus-v1:0
+label:
+  en_US: Command R+
+model_type: llm
+features:
+  #- multi-tool-call
+  - agent-thought
+  #- stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '3'
+  output: '15'
+  unit: '0.000001'
+  currency: USD

+ 45 - 0
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml

@@ -0,0 +1,45 @@
+model: cohere.command-r-v1:0
+label:
+  en_US: Command R
+model_type: llm
+features:
+  #- multi-tool-call
+  - agent-thought
+  #- stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '0.5'
+  output: '1.5'
+  unit: '0.000001'
+  currency: USD

+ 80 - 0
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -25,6 +25,7 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
+from cohere import ChatMessage
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -48,6 +49,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
@@ -75,8 +77,86 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # invoke anthropic models via anthropic official SDK
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        # invoke Cohere models via boto3 client
+        if "cohere.command-r" in model:
+            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+    
+    def _generate_cohere_chat(
+            self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+            tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
+        cohere_llm = CohereLargeLanguageModel()
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
+        runtime_client = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        if tools:
+            tools = cohere_llm._convert_tools(tools)
+            model_parameters['tools'] = tools
+
+        message, chat_histories, tool_results \
+            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
+
+        payload = {
+            **model_parameters,
+            "message": message,
+            "chat_history": chat_histories,
+        }
+
+        # need workaround for ai21 models which doesn't support streaming
+        if stream:
+            invoke = runtime_client.invoke_model_with_response_stream
+        else:
+            invoke = runtime_client.invoke_model
+
+        def serialize(obj):
+            if isinstance(obj, ChatMessage):
+                return obj.__dict__
+            raise TypeError(f"Type {type(obj)} not serializable")
+
+        try:
+            body_jsonstr=json.dumps(payload, default=serialize)
+            response = invoke(
+                modelId=model,
+                contentType="application/json",
+                accept="*/*",
+                body=body_jsonstr
+            )
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
 
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: