Browse Source

fix: minimax streaming function_call message (#4271)

Weaxs 11 months ago
parent
commit
8cc492721b

+ 31 - 45
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
         Minimax Chat Completion Pro API, supports function calling
         however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
-    def generate(self, model: str, api_key: str, group_id: str, 
+    def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-        -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
+            -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
         """
         if not api_key or not group_id:
             raise InvalidAPIKeyError('Invalid API key or group ID')
-        
+
         url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
 
         extra_kwargs = {}
@@ -42,7 +42,7 @@ class MinimaxChatCompletionPro:
 
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
             extra_kwargs['top_p'] = model_parameters['top_p']
-        
+
         if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
             extra_kwargs['plugins'] = [
                 'plugin_web_search'
@@ -61,7 +61,7 @@ class MinimaxChatCompletionPro:
         # check if there is a system message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one message is required')
-        
+
         if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
             if prompt_messages[0].content:
                 bot_setting['content'] = prompt_messages[0].content
@@ -70,7 +70,7 @@ class MinimaxChatCompletionPro:
         # check if there is a user message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one user message is required')
-        
+
         messages = [message.to_dict() for message in prompt_messages]
 
         headers = {
@@ -89,21 +89,21 @@ class MinimaxChatCompletionPro:
 
         if tools:
             body['functions'] = tools
-            body['function_call'] = { 'type': 'auto' }
+            body['function_call'] = {'type': 'auto'}
 
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
         except Exception as e:
             raise InternalServerError(e)
-        
+
         if response.status_code != 200:
             raise InternalServerError(response.text)
-        
+
         if stream:
             return self._handle_stream_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
-    
+
     def _handle_error(self, code: int, msg: str):
         if code == 1000 or code == 1001 or code == 1013 or code == 1027:
             raise InternalServerError(msg)
@@ -127,7 +127,7 @@ class MinimaxChatCompletionPro:
             code = response['base_resp']['status_code']
             msg = response['base_resp']['status_msg']
             self._handle_error(code, msg)
-        
+
         message = MinimaxMessage(
             content=response['reply'],
             role=MinimaxMessage.Role.ASSISTANT.value
@@ -144,7 +144,6 @@ class MinimaxChatCompletionPro:
         """
             handle stream chat generate response
         """
-        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -158,54 +157,41 @@ class MinimaxChatCompletionPro:
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
 
+            # final chunk
             if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
-                message =  MinimaxMessage(
+                minimax_message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
-                message.usage = {
+                minimax_message.usage = {
                     'prompt_tokens': 0,
                     'completion_tokens': total_tokens,
                     'total_tokens': total_tokens
                 }
-                message.stop_reason = data['choices'][0]['finish_reason']
-
-                if function_call_storage:
-                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-                    function_call_message.function_call = function_call_storage
-                    yield function_call_message
-
-                yield message
+                minimax_message.stop_reason = data['choices'][0]['finish_reason']
+
+                choices = data.get('choices', [])
+                if len(choices) > 0:
+                    for choice in choices:
+                        message = choice['messages'][0]
+                        # append function_call message
+                        if 'function_call' in message:
+                            function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                            function_call_message.function_call = message['function_call']
+                            yield function_call_message
+
+                yield minimax_message
                 return
 
+            # partial chunk
             choices = data.get('choices', [])
             if len(choices) == 0:
                 continue
 
             for choice in choices:
                 message = choice['messages'][0]
-
-                if 'function_call' in message:
-                    if not function_call_storage:
-                        function_call_storage = message['function_call']
-                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
-                            function_call_storage['arguments'] = ''
-                            continue
-                    else:
-                        function_call_storage['arguments'] += message['function_call']['arguments']
-                        continue
-                else:
-                    if function_call_storage:
-                        message['function_call'] = function_call_storage
-                        function_call_storage = None
-                
-                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-
-                if 'function_call' in message:
-                    minimax_message.function_call = message['function_call']
-
+                # append text message
                 if 'text' in message:
-                    minimax_message.content = message['text']
-
-                yield minimax_message
+                    minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
+                    yield minimax_message