
feat: bedrock model runtime enhancement (#6299)

longzhihun 9 months ago
parent
commit
ed9e692263

+ 73 - 20
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -48,6 +48,28 @@ logger = logging.getLogger(__name__)
 
 class BedrockLargeLanguageModel(LargeLanguageModel):
 
+    # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+    # TODO: Cohere models currently hit a context-limit issue when invoked via the Converse API; add them once that is fixed.
+    CONVERSE_API_ENABLED_MODEL_INFO = [
+        {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
+    ]
+
+    @staticmethod
+    def _find_model_info(model_id):
+        for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
+            if model_id.startswith(model['prefix']):
+                return model
+        logger.info(f"current model id: {model_id} is not supported by the Converse API")
+        return None
+
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
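Note: `_find_model_info` does a simple prefix match, so a fully qualified Bedrock model id resolves to its family's entry in the table. A minimal sketch of the expected behavior (the concrete model ids below are illustrative):

    info = BedrockLargeLanguageModel._find_model_info('anthropic.claude-3-sonnet-20240229-v1:0')
    # -> {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}

    info = BedrockLargeLanguageModel._find_model_info('cohere.command-r-v1:0')
    # -> None: logs that the model is not supported by the Converse API and
    #    falls through to the model-specific invocation paths in _invoke below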
@@ -66,10 +88,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # TODO: consolidate different invocation methods for models based on base model capabilities
-        # invoke anthropic models via boto3 client
-        if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
+
+        model_info = BedrockLargeLanguageModel._find_model_info(model)
+        if model_info:
+            model_info = {**model_info, 'model': model}  # copy so the shared class-level entry is never mutated
+            # invoke models via the boto3 Converse API
+            return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke Cohere models via boto3 client
         if "cohere.command-r" in model:
             return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,12 +175,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
 
-    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+    def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
         """
-        Invoke Anthropic large language model
+        Invoke large language model with converse API
 
-        :param model: model name
+        :param model_info: model information
         :param credentials: model credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
@@ -173,24 +197,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
 
         parameters = {
-            'modelId': model,
+            'modelId': model_info['model'],
             'messages': prompt_message_dicts,
             'inferenceConfig': inference_config,
             'additionalModelRequestFields': additional_model_fields,
         }
 
-        if system and len(system) > 0:
+        if model_info['support_system_prompts'] and system and len(system) > 0:
             parameters['system'] = system
 
-        if tools:
+        if model_info['support_tool_use'] and tools:
             parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
 
         if stream:
             response = bedrock_client.converse_stream(**parameters)
-            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
         else:
             response = bedrock_client.converse(**parameters)
-            return self._handle_converse_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
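For context, the `parameters` dict assembled above maps one-to-one onto the boto3 bedrock-runtime Converse API. A minimal standalone sketch of the non-streaming path (region, model id, and prompt are placeholders; the real code builds the client from `credentials`):

    import boto3

    client = boto3.client('bedrock-runtime', region_name='us-east-1')    # placeholder region
    response = client.converse(
        modelId='anthropic.claude-3-sonnet-20240229-v1:0',               # placeholder model id
        messages=[{'role': 'user', 'content': [{'text': 'Hello'}]}],
        system=[{'text': 'You are a helpful assistant.'}],
        inferenceConfig={'maxTokens': 512, 'temperature': 0.5},
    )
    print(response['output']['message']['content'][0]['text'])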
 
     def _handle_converse_response(self, model: str, credentials: dict, response: dict,
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
@@ -203,10 +227,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
+        response_content = response['output']['message']['content']
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=response['output']['message']['content'][0]['text']
-        )
+        if response['stopReason'] == 'tool_use':
+            tool_calls = []
+            text, tool_use = self._extract_tool_use(response_content)
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_use['toolUseId'],
+                type='function',
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=tool_use['name'],
+                    arguments=json.dumps(tool_use['input'])
+                )
+            )
+            tool_calls.append(tool_call)
+
+            assistant_prompt_message = AssistantPromptMessage(
+                content=text,
+                tool_calls=tool_calls
+            )
+        else:
+            assistant_prompt_message = AssistantPromptMessage(
+                content=response_content[0]['text']
+            )
 
         # calculate num tokens
         if response['usage']:
@@ -229,6 +273,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
         return result
 
+    def _extract_tool_use(self, content: list[dict]) -> tuple[str, dict]:
+        tool_use = {}
+        text = ''
+        for item in content:
+            if 'toolUse' in item:
+                tool_use = item['toolUse']
+            elif 'text' in item:
+                text = item['text']
+            else:
+                raise ValueError(f"Got unknown item: {item}")
+        return text, tool_use
+
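When `stopReason` is `tool_use`, the Converse API returns a content list that mixes an optional text block with a `toolUse` block; `_extract_tool_use` splits the two. An illustrative example (the id and values are made up):

    content = [
        {'text': 'Let me look that up.'},
        {'toolUse': {'toolUseId': 'tooluse_abc123', 'name': 'get_weather', 'input': {'city': 'Seattle'}}},
    ]
    text, tool_use = model._extract_tool_use(content)
    # text     -> 'Let me look that up.'
    # tool_use -> {'toolUseId': 'tooluse_abc123', 'name': 'get_weather', 'input': {'city': 'Seattle'}}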
     def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                         prompt_messages: list[PromptMessage], ) -> Generator:
         """
@@ -340,14 +396,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
 
         system = []
+        prompt_message_dicts = []
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
                message.content = message.content.strip()
                 system.append({"text": message.content})
-
-        prompt_message_dicts = []
-        for message in prompt_messages:
-            if not isinstance(message, SystemPromptMessage):
+            else:
                 prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
 
         return system, prompt_message_dicts
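The refactor above collapses two passes over `prompt_messages` into one: system prompts are collected into the Converse `system` list and every other message is converted to a message dict. A sketch of the resulting split (UserPromptMessage comes from the same entities module as SystemPromptMessage; the message-dict shape is assumed from `_convert_prompt_message_to_dict`):

    messages = [
        SystemPromptMessage(content='  You are concise.  '),
        UserPromptMessage(content='Hi'),
    ]
    system, prompt_message_dicts = model._convert_converse_prompt_messages(messages)
    # system               -> [{'text': 'You are concise.'}]   (whitespace stripped)
    # prompt_message_dicts -> [{'role': 'user', 'content': [{'text': 'Hi'}]}]   (assumed shape)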
@@ -448,7 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             }
         else:
             raise ValueError(f"Got unknown type {message}")
-
         return message_dict
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,

+ 3 - 0
api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2402-v1.0.yaml

@@ -2,6 +2,9 @@ model: mistral.mistral-large-2402-v1:0
 label:
   en_US: Mistral Large
 model_type: llm
+features:
+  - tool-call
+  - agent-thought
 model_properties:
   mode: completion
   context_size: 32000

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-small-2402-v1.0.yaml

@@ -2,6 +2,8 @@ model: mistral.mistral-small-2402-v1:0
 label:
   en_US: Mistral Small
 model_type: llm
+features:
+  - tool-call
 model_properties:
   mode: completion
   context_size: 32000