Forráskód Böngészése

feat: support doubao llm function calling (#5100)

sino 10 hónapja
szülő
commit
0ce97e6315

+ 28 - 2
api/core/model_runtime/model_providers/volcengine_maas/client.py

@@ -7,7 +7,9 @@ from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
@@ -36,10 +38,11 @@ class MaaSClient(MaasService):
         client.set_sk(sk)
         return client
 
-    def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
         req = {
             'parameters': params,
-            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
+            **extra_model_kwargs,
         }
         if not stream:
             return super().chat(
@@ -89,10 +92,22 @@ class MaaSClient(MaasService):
             message = cast(AssistantPromptMessage, message)
             message_dict = {'role': ChatRole.ASSISTANT,
                             'content': message.content}
+            if message.tool_calls:
+                message_dict['tool_calls'] = [
+                    {
+                        'name': call.function.name,
+                        'arguments': call.function.arguments
+                    } for call in message.tool_calls
+                ]
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {'role': ChatRole.SYSTEM,
                             'content': message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {'role': ChatRole.FUNCTION,
+                            'content': message.content,
+                            'name': message.tool_call_id}
         else:
             raise ValueError(f"Got unknown PromptMessage type {message}")
 
@@ -106,3 +121,14 @@ class MaaSClient(MaasService):
             raise wrap_error(e)
 
         return resp
+
+    @staticmethod
+    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            }
+        }

+ 29 - 3
api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if stop:
             req_params['stop'] = stop
 
+        extra_model_kwargs = {}
+        
+        if tools:
+            extra_model_kwargs['tools'] = [
+                MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
+            ]
+
         resp = MaaSClient.wrap_exception(
-            lambda: client.chat(req_params, prompt_messages, stream))
+            lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
         if not stream:
             return self._handle_chat_response(model, credentials, prompt_messages, resp)
         return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
@@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         choice = choices[0]
         message = choice['message']
 
+        # parse tool calls
+        tool_calls = []
+        if message['tool_calls']:
+            for call in message['tool_calls']:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=call['function']['name'],
+                    type=call['type'],
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=call['function']['name'],
+                        arguments=call['function']['arguments']
+                    )
+                )
+                tool_calls.append(tool_call)
+
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
                 content=message['content'] if message['content'] else '',
-                tool_calls=[],
+                tool_calls=tool_calls,
             ),
             usage=self._calc_usage(model, credentials, resp['usage']),
         )
@@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         if credentials.get('context_size'):
             model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
                 credentials.get('context_size', 4096))
+
+        model_features = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('features', [])
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
             model_properties=model_properties,
-            parameter_rules=rules
+            parameter_rules=rules,
+            features=model_features,
         )
 
         return entity

+ 34 - 11
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -1,3 +1,5 @@
+from core.model_runtime.entities.model_entities import ModelFeature
+
 ModelConfigs = {
     'Doubao-pro-4k': {
         'req_params': {
@@ -7,7 +9,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-4k': {
         'req_params': {
@@ -17,7 +22,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-32k': {
         'req_params': {
@@ -27,7 +35,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-32k': {
         'req_params': {
@@ -37,7 +48,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 32768,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-pro-128k': {
         'req_params': {
@@ -47,7 +61,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Doubao-lite-128k': {
         'req_params': {
@@ -57,7 +74,10 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 131072,
             'mode': 'chat',
-        }
+        },
+        'features': [
+            ModelFeature.TOOL_CALL
+        ],
     },
     'Skylark2-pro-4k': {
         'req_params': {
@@ -67,26 +87,29 @@ ModelConfigs = {
         'model_properties': {
             'context_size': 4096,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-8B': {
-         'req_params': {
+        'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     },
     'Llama3-70B': {
-         'req_params': {
+        'req_params': {
             'max_prompt_tokens': 8192,
             'max_new_tokens': 8192,
         },
         'model_properties': {
             'context_size': 8192,
             'mode': 'chat',
-        }
+        },
+        'features': [],
     }
 }