Ver código fonte

feat: bedrock invoke enhancement (#6808)

longzhihun 8 meses atrás
pai
commit
9ce5cea911

+ 1 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -12,6 +12,7 @@
 - cohere.command-r-v1.0
 - meta.llama3-1-8b-instruct-v1:0
 - meta.llama3-1-70b-instruct-v1:0
+- meta.llama3-1-405b-instruct-v1:0
 - meta.llama3-8b-instruct-v1:0
 - meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1

+ 1 - 2
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-plus-v1.0.yaml

@@ -3,8 +3,7 @@ label:
   en_US: Command R+
 model_type: llm
 features:
-  #- multi-tool-call
-  - agent-thought
+  - tool-call
   #- stream-tool-call
 model_properties:
   mode: chat

+ 1 - 3
api/core/model_runtime/model_providers/bedrock/llm/cohere.command-r-v1.0.yaml

@@ -3,9 +3,7 @@ label:
   en_US: Command R
 model_type: llm
 features:
-  #- multi-tool-call
-  - agent-thought
-  #- stream-tool-call
+  - tool-call
 model_properties:
   mode: chat
   context_size: 128000

+ 9 - 175
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -17,7 +17,6 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
-from cohere import ChatMessage
 
 # local import
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -42,7 +41,6 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +57,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
         {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
         {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True},
         {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
     ]
 
@@ -94,86 +93,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             model_info['model'] = model
             # invoke models via boto3 converse API
             return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
-        # invoke Cohere models via boto3 client
-        if "cohere.command-r" in model:
-            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
-    
-    def _generate_cohere_chat(
-            self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
-            tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
-        cohere_llm = CohereLargeLanguageModel()
-        client_config = Config(
-            region_name=credentials["aws_region"]
-        )
-
-        runtime_client = boto3.client(
-            service_name='bedrock-runtime',
-            config=client_config,
-            aws_access_key_id=credentials["aws_access_key_id"],
-            aws_secret_access_key=credentials["aws_secret_access_key"]
-        )
-
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        if tools:
-            tools = cohere_llm._convert_tools(tools)
-            model_parameters['tools'] = tools
-
-        message, chat_histories, tool_results \
-            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
-
-        if tool_results:
-            model_parameters['tool_results'] = tool_results
-
-        payload = {
-            **model_parameters,
-            "message": message,
-            "chat_history": chat_histories,
-        }
-
-        # need workaround for ai21 models which doesn't support streaming
-        if stream:
-            invoke = runtime_client.invoke_model_with_response_stream
-        else:
-            invoke = runtime_client.invoke_model
-
-        def serialize(obj):
-            if isinstance(obj, ChatMessage):
-                return obj.__dict__
-            raise TypeError(f"Type {type(obj)} not serializable")
-
-        try:
-            body_jsonstr=json.dumps(payload, default=serialize)
-            response = invoke(
-                modelId=model,
-                contentType="application/json",
-                accept="*/*",
-                body=body_jsonstr
-            )
-        except ClientError as ex:
-            error_code = ex.response['Error']['Code']
-            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
-            raise self._map_client_to_invoke_error(error_code, full_error_msg)
-
-        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
-            raise InvokeConnectionError(str(ex))
-
-        except UnknownServiceError as ex:
-            raise InvokeServerUnavailableError(str(ex))
-
-        except Exception as ex:
-            raise InvokeError(str(ex))
-
-        if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
-
 
     def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
@@ -581,38 +502,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param message: PromptMessage to convert.
         :return: String representation of the message.
         """
-        
-        if model_prefix == "anthropic":
-            human_prompt_prefix = "\n\nHuman:"
-            human_prompt_postfix = ""
-            ai_prompt = "\n\nAssistant:"
-
-        elif model_prefix == "meta":
-            # LLAMA3
-            if model_name.startswith("llama3"):
-                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
-                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-                ai_prompt = "\n\nAssistant:"
-            else:
-                # LLAMA2
-                human_prompt_prefix = "\n[INST]"
-                human_prompt_postfix = "[\\INST]\n"
-                ai_prompt = ""
-
-        elif model_prefix == "mistral":
-            human_prompt_prefix = "<s>[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = "\n\nAssistant:"
-
-        elif model_prefix == "amazon":
-            human_prompt_prefix = "\n\nUser:"
-            human_prompt_postfix = ""
-            ai_prompt = "\n\nBot:"
-        
-        else:
-            human_prompt_prefix = ""
-            human_prompt_postfix = ""
-            ai_prompt = ""
+        human_prompt_prefix = ""
+        human_prompt_postfix = ""
+        ai_prompt = ""
 
         content = message.content
 
@@ -663,13 +555,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         model_prefix = model.split('.')[0]
         model_name = model.split('.')[1]
 
-        if model_prefix == "amazon":
-            payload["textGenerationConfig"] = { **model_parameters }
-            payload["textGenerationConfig"]["stopSequences"] = ["User:"]
-            
-            payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-        
-        elif model_prefix == "ai21":
+        if model_prefix == "ai21":
             payload["temperature"] = model_parameters.get("temperature")
             payload["topP"] = model_parameters.get("topP")
             payload["maxTokens"] = model_parameters.get("maxTokens")
@@ -681,28 +567,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
             if model_parameters.get("countPenalty"):
                 payload["countPenalty"] = {model_parameters.get("countPenalty")}
-        
-        elif model_prefix == "mistral":
-            payload["temperature"] = model_parameters.get("temperature")
-            payload["top_p"] = model_parameters.get("top_p")
-            payload["max_tokens"] = model_parameters.get("max_tokens")
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-            payload["stop"] = stop[:10] if stop else []
-
-        elif model_prefix == "anthropic":
-            payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
-            payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
-            
+                    
         elif model_prefix == "cohere":
             payload = { **model_parameters }
             payload["prompt"] = prompt_messages[0].content
             payload["stream"] = stream
         
-        elif model_prefix == "meta":
-            payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
-
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
         
@@ -793,36 +663,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # get output text and calculate num tokens based on model / provider
         model_prefix = model.split('.')[0]
 
-        if model_prefix == "amazon":
-            output = response_body.get("results")[0].get("outputText").strip('\n')
-            prompt_tokens = response_body.get("inputTextTokenCount")
-            completion_tokens = response_body.get("results")[0].get("tokenCount")
-
-        elif model_prefix == "ai21":
+        if model_prefix == "ai21":
             output = response_body.get('completions')[0].get('data').get('text')
             prompt_tokens = len(response_body.get("prompt").get("tokens"))
             completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
-
-        elif model_prefix == "anthropic":
-            output = response_body.get("completion")
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
             
         elif model_prefix == "cohere":
             output = response_body.get("generations")[0].get("text")
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
             completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
-            
-        elif model_prefix == "meta":
-            output = response_body.get("generation").strip('\n')
-            prompt_tokens = response_body.get("prompt_token_count")
-            completion_tokens = response_body.get("generation_token_count")
-        
-        elif model_prefix == "mistral":
-            output = response_body.get("outputs")[0].get("text")
-            prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
-            completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
-
+                        
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
 
@@ -893,26 +743,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             payload = json.loads(chunk.get('bytes').decode())
 
             model_prefix = model.split('.')[0]
-            if model_prefix == "amazon":
-                content_delta = payload.get("outputText").strip('\n')
-                finish_reason = payload.get("completion_reason")
- 
-            elif model_prefix == "anthropic":
-                content_delta = payload.get("completion")
-                finish_reason = payload.get("stop_reason")
-
-            elif model_prefix == "cohere":
+            if model_prefix == "cohere":
                 content_delta = payload.get("text")
                 finish_reason = payload.get("finish_reason")
             
-            elif model_prefix == "mistral":
-                content_delta = payload.get('outputs')[0].get("text")
-                finish_reason = payload.get('outputs')[0].get("stop_reason")
-
-            elif model_prefix == "meta":
-                content_delta = payload.get("generation").strip('\n')
-                finish_reason = payload.get("stop_reason")
-            
             else:
                 raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
 

+ 25 - 0
api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-405b-instruct-v1.0.yaml

@@ -0,0 +1,25 @@
+model: meta.llama3-1-405b-instruct-v1:0
+label:
+  en_US: Llama 3.1 405B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    default: 0.9
+  - name: max_gen_len
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 2048
+pricing:
+  input: '0.00532'
+  output: '0.016'
+  unit: '0.001'
+  currency: USD