Переглянути джерело

fix: Correct image parameter passing in GLM-4v model API calls (#2948)

Weishan-0 1 рік тому
батько
коміт
a676d4387c
1 змінений файл: 51 додано та 37 видалено
  1. 51 37
      api/core/model_runtime/model_providers/zhipuai/llm/llm.py

+ 51 - 37
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     PromptMessageRole,
     PromptMessageTool,
@@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
 
 ```JSON"""
 
+
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
@@ -159,7 +161,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         if len(prompt_messages) == 0:
             raise ValueError('At least one message is required')
-        
+
         if prompt_messages[0].role == PromptMessageRole.SYSTEM:
             if not prompt_messages[0].content:
                 prompt_messages = prompt_messages[1:]
@@ -185,7 +187,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     continue
 
                 if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
-                    copy_prompt_message.role == PromptMessageRole.USER:
+                        copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
                     if copy_prompt_message.role == PromptMessageRole.USER:
@@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     new_prompt_messages.append(copy_prompt_message)
 
         if model == 'glm-4v':
-            params = {
-                'model': model,
-                'messages': [{
-                    'role': prompt_message.role.value,
-                    'content': 
-                        [
-                            {
-                                'type': 'text',
-                                'text': prompt_message.content
-                            }
-                        ] if isinstance(prompt_message.content, str) else 
-                        [
-                            {
-                                'type': 'image',
-                                'image_url': {
-                                    'url': content.data
-                                }
-                            } if content.type == PromptMessageContentType.IMAGE else {
-                                'type': 'text',
-                                'text': content.data
-                            } for content in prompt_message.content
-                        ],
-                } for prompt_message in new_prompt_messages],
-                **model_parameters
-            }
+            params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
             params = {
                 'model': model,
@@ -277,8 +255,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 for prompt_message in new_prompt_messages:
                     # merge system message to user message
                     if prompt_message.role == PromptMessageRole.SYSTEM or \
-                        prompt_message.role == PromptMessageRole.TOOL or \
-                        prompt_message.role == PromptMessageRole.USER:
+                            prompt_message.role == PromptMessageRole.TOOL or \
+                            prompt_message.role == PromptMessageRole.USER:
                         if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
                             params['messages'][-1]['content'] += "\n\n" + prompt_message.content
                         else:
@@ -306,8 +284,44 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         response = client.chat.completions.create(**params, **extra_model_kwargs)
         return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
-        
-    def _handle_generate_response(self, model: str, 
+
+    def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage],
+                                    model_parameters: dict):
+        messages = [
+            {
+                'role': message.role.value,
+                'content': self._construct_glm_4v_messages(message.content)
+            }
+            for message in prompt_messages
+        ]
+
+        params = {
+            'model': model,
+            'messages': messages,
+            **model_parameters
+        }
+
+        return params
+
+    def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]:
+        if isinstance(prompt_message, str):
+            return [{'type': 'text', 'text': prompt_message}]
+
+        return [
+            {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}}
+            if item.type == PromptMessageContentType.IMAGE else
+            {'type': 'text', 'text': item.data}
+
+            for item in prompt_message
+        ]
+
+    def _remove_image_header(self, image: str) -> str:
+        if image.startswith('data:image'):
+            return image.split(',')[1]
+
+        return image
+
+    def _handle_generate_response(self, model: str,
                                   credentials: dict,
                                   tools: Optional[list[PromptMessageTool]],
                                   response: Completion,
@@ -338,7 +352,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                         )
 
             text += choice.message.content or ''
-          
+
         prompt_usage = response.usage.prompt_tokens
         completion_usage = response.usage.completion_tokens
 
@@ -358,7 +372,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, 
+    def _handle_generate_stream_response(self, model: str,
                                          credentials: dict,
                                          tools: Optional[list[PromptMessageTool]],
                                          responses: Generator[ChatCompletionChunk, None, None],
@@ -380,7 +394,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
-            
+
             assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
             for tool_call in delta.delta.tool_calls or []:
                 if tool_call.type == 'function':
@@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return message_text
 
-
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage],
+                                    tools: Optional[list[PromptMessageTool]] = None) -> str:
         """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
@@ -473,4 +487,4 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 text += f"\n{tool.json()}"
 
         # trim off the trailing ' ' that might come from the "Assistant: "
-        return text.rstrip()
+        return text.rstrip()