Browse Source

fix: minimax streaming function_call message (#4271)

Weaxs 1 year ago
parent
commit
8cc492721b

+ 31 - 45
api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

@@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
         Minimax Chat Completion Pro API, supports function calling
         Minimax Chat Completion Pro API, supports function calling
         however, we do not have enough time and energy to implement it, but the parameters are reserved
         however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
     """
-    def generate(self, model: str, api_key: str, group_id: str, 
+    def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
-        -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
+            -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         """
             generate chat completion
             generate chat completion
         """
         """
         if not api_key or not group_id:
         if not api_key or not group_id:
             raise InvalidAPIKeyError('Invalid API key or group ID')
             raise InvalidAPIKeyError('Invalid API key or group ID')
-        
+
         url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
         url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
 
 
         extra_kwargs = {}
         extra_kwargs = {}
@@ -42,7 +42,7 @@ class MinimaxChatCompletionPro:
 
 
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
             extra_kwargs['top_p'] = model_parameters['top_p']
             extra_kwargs['top_p'] = model_parameters['top_p']
-        
+
         if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
         if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
             extra_kwargs['plugins'] = [
             extra_kwargs['plugins'] = [
                 'plugin_web_search'
                 'plugin_web_search'
@@ -61,7 +61,7 @@ class MinimaxChatCompletionPro:
         # check if there is a system message
         # check if there is a system message
         if len(prompt_messages) == 0:
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one message is required')
             raise BadRequestError('At least one message is required')
-        
+
         if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
         if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
             if prompt_messages[0].content:
             if prompt_messages[0].content:
                 bot_setting['content'] = prompt_messages[0].content
                 bot_setting['content'] = prompt_messages[0].content
@@ -70,7 +70,7 @@ class MinimaxChatCompletionPro:
         # check if there is a user message
         # check if there is a user message
         if len(prompt_messages) == 0:
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one user message is required')
             raise BadRequestError('At least one user message is required')
-        
+
         messages = [message.to_dict() for message in prompt_messages]
         messages = [message.to_dict() for message in prompt_messages]
 
 
         headers = {
         headers = {
@@ -89,21 +89,21 @@ class MinimaxChatCompletionPro:
 
 
         if tools:
         if tools:
             body['functions'] = tools
             body['functions'] = tools
-            body['function_call'] = { 'type': 'auto' }
+            body['function_call'] = {'type': 'auto'}
 
 
         try:
         try:
             response = post(
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
         except Exception as e:
         except Exception as e:
             raise InternalServerError(e)
             raise InternalServerError(e)
-        
+
         if response.status_code != 200:
         if response.status_code != 200:
             raise InternalServerError(response.text)
             raise InternalServerError(response.text)
-        
+
         if stream:
         if stream:
             return self._handle_stream_chat_generate_response(response)
             return self._handle_stream_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
-    
+
     def _handle_error(self, code: int, msg: str):
     def _handle_error(self, code: int, msg: str):
         if code == 1000 or code == 1001 or code == 1013 or code == 1027:
         if code == 1000 or code == 1001 or code == 1013 or code == 1027:
             raise InternalServerError(msg)
             raise InternalServerError(msg)
@@ -127,7 +127,7 @@ class MinimaxChatCompletionPro:
             code = response['base_resp']['status_code']
             code = response['base_resp']['status_code']
             msg = response['base_resp']['status_msg']
             msg = response['base_resp']['status_msg']
             self._handle_error(code, msg)
             self._handle_error(code, msg)
-        
+
         message = MinimaxMessage(
         message = MinimaxMessage(
             content=response['reply'],
             content=response['reply'],
             role=MinimaxMessage.Role.ASSISTANT.value
             role=MinimaxMessage.Role.ASSISTANT.value
@@ -144,7 +144,6 @@ class MinimaxChatCompletionPro:
         """
         """
             handle stream chat generate response
             handle stream chat generate response
         """
         """
-        function_call_storage = None
         for line in response.iter_lines():
         for line in response.iter_lines():
             if not line:
             if not line:
                 continue
                 continue
@@ -158,54 +157,41 @@ class MinimaxChatCompletionPro:
                 msg = data['base_resp']['status_msg']
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
                 self._handle_error(code, msg)
 
 
+            # final chunk
             if data['reply'] or 'usage' in data and data['usage']:
             if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 total_tokens = data['usage']['total_tokens']
-                message =  MinimaxMessage(
+                minimax_message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                     content=''
                 )
                 )
-                message.usage = {
+                minimax_message.usage = {
                     'prompt_tokens': 0,
                     'prompt_tokens': 0,
                     'completion_tokens': total_tokens,
                     'completion_tokens': total_tokens,
                     'total_tokens': total_tokens
                     'total_tokens': total_tokens
                 }
                 }
-                message.stop_reason = data['choices'][0]['finish_reason']
-
-                if function_call_storage:
-                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-                    function_call_message.function_call = function_call_storage
-                    yield function_call_message
-
-                yield message
+                minimax_message.stop_reason = data['choices'][0]['finish_reason']
+
+                choices = data.get('choices', [])
+                if len(choices) > 0:
+                    for choice in choices:
+                        message = choice['messages'][0]
+                        # append function_call message
+                        if 'function_call' in message:
+                            function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                            function_call_message.function_call = message['function_call']
+                            yield function_call_message
+
+                yield minimax_message
                 return
                 return
 
 
+            # partial chunk
             choices = data.get('choices', [])
             choices = data.get('choices', [])
             if len(choices) == 0:
             if len(choices) == 0:
                 continue
                 continue
 
 
             for choice in choices:
             for choice in choices:
                 message = choice['messages'][0]
                 message = choice['messages'][0]
-
-                if 'function_call' in message:
-                    if not function_call_storage:
-                        function_call_storage = message['function_call']
-                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
-                            function_call_storage['arguments'] = ''
-                            continue
-                    else:
-                        function_call_storage['arguments'] += message['function_call']['arguments']
-                        continue
-                else:
-                    if function_call_storage:
-                        message['function_call'] = function_call_storage
-                        function_call_storage = None
-                
-                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-
-                if 'function_call' in message:
-                    minimax_message.function_call = message['function_call']
-
+                # append text message
                 if 'text' in message:
                 if 'text' in message:
-                    minimax_message.content = message['text']
-
-                yield minimax_message
+                    minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
+                    yield minimax_message