|
@@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
AssistantPromptMessage,
|
|
|
PromptMessage,
|
|
|
+ PromptMessageContent,
|
|
|
PromptMessageContentType,
|
|
|
PromptMessageRole,
|
|
|
PromptMessageTool,
|
|
@@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
|
|
|
|
|
|
```JSON"""
|
|
|
|
|
|
+
|
|
|
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
def _invoke(self, model: str, credentials: dict,
|
|
@@ -159,7 +161,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
if len(prompt_messages) == 0:
|
|
|
raise ValueError('At least one message is required')
|
|
|
-
|
|
|
+
|
|
|
if prompt_messages[0].role == PromptMessageRole.SYSTEM:
|
|
|
if not prompt_messages[0].content:
|
|
|
prompt_messages = prompt_messages[1:]
|
|
@@ -185,7 +187,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
continue
|
|
|
|
|
|
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
|
|
|
- copy_prompt_message.role == PromptMessageRole.USER:
|
|
|
+ copy_prompt_message.role == PromptMessageRole.USER:
|
|
|
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
|
|
else:
|
|
|
if copy_prompt_message.role == PromptMessageRole.USER:
|
|
@@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
new_prompt_messages.append(copy_prompt_message)
|
|
|
|
|
|
if model == 'glm-4v':
|
|
|
- params = {
|
|
|
- 'model': model,
|
|
|
- 'messages': [{
|
|
|
- 'role': prompt_message.role.value,
|
|
|
- 'content':
|
|
|
- [
|
|
|
- {
|
|
|
- 'type': 'text',
|
|
|
- 'text': prompt_message.content
|
|
|
- }
|
|
|
- ] if isinstance(prompt_message.content, str) else
|
|
|
- [
|
|
|
- {
|
|
|
- 'type': 'image',
|
|
|
- 'image_url': {
|
|
|
- 'url': content.data
|
|
|
- }
|
|
|
- } if content.type == PromptMessageContentType.IMAGE else {
|
|
|
- 'type': 'text',
|
|
|
- 'text': content.data
|
|
|
- } for content in prompt_message.content
|
|
|
- ],
|
|
|
- } for prompt_message in new_prompt_messages],
|
|
|
- **model_parameters
|
|
|
- }
|
|
|
+ params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
|
|
else:
|
|
|
params = {
|
|
|
'model': model,
|
|
@@ -277,8 +255,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
for prompt_message in new_prompt_messages:
|
|
|
# merge system message to user message
|
|
|
if prompt_message.role == PromptMessageRole.SYSTEM or \
|
|
|
- prompt_message.role == PromptMessageRole.TOOL or \
|
|
|
- prompt_message.role == PromptMessageRole.USER:
|
|
|
+ prompt_message.role == PromptMessageRole.TOOL or \
|
|
|
+ prompt_message.role == PromptMessageRole.USER:
|
|
|
if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
|
|
|
params['messages'][-1]['content'] += "\n\n" + prompt_message.content
|
|
|
else:
|
|
@@ -306,8 +284,44 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
response = client.chat.completions.create(**params, **extra_model_kwargs)
|
|
|
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
|
|
|
-
|
|
|
- def _handle_generate_response(self, model: str,
|
|
|
+
|
|
|
+ def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage],
|
|
|
+ model_parameters: dict):
|
|
|
+ messages = [
|
|
|
+ {
|
|
|
+ 'role': message.role.value,
|
|
|
+ 'content': self._construct_glm_4v_messages(message.content)
|
|
|
+ }
|
|
|
+ for message in prompt_messages
|
|
|
+ ]
|
|
|
+
|
|
|
+ params = {
|
|
|
+ 'model': model,
|
|
|
+ 'messages': messages,
|
|
|
+ **model_parameters
|
|
|
+ }
|
|
|
+
|
|
|
+ return params
|
|
|
+
|
|
|
+ def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]:
|
|
|
+ if isinstance(prompt_message, str):
|
|
|
+ return [{'type': 'text', 'text': prompt_message}]
|
|
|
+
|
|
|
+ return [
|
|
|
+ {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}}
|
|
|
+ if item.type == PromptMessageContentType.IMAGE else
|
|
|
+ {'type': 'text', 'text': item.data}
|
|
|
+
|
|
|
+ for item in prompt_message
|
|
|
+ ]
|
|
|
+
|
|
|
+ def _remove_image_header(self, image: str) -> str:
|
|
|
+ if image.startswith('data:image'):
|
|
|
+ return image.split(',')[1]
|
|
|
+
|
|
|
+ return image
|
|
|
+
|
|
|
+ def _handle_generate_response(self, model: str,
|
|
|
credentials: dict,
|
|
|
tools: Optional[list[PromptMessageTool]],
|
|
|
response: Completion,
|
|
@@ -338,7 +352,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
)
|
|
|
|
|
|
text += choice.message.content or ''
|
|
|
-
|
|
|
+
|
|
|
prompt_usage = response.usage.prompt_tokens
|
|
|
completion_usage = response.usage.completion_tokens
|
|
|
|
|
@@ -358,7 +372,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
return result
|
|
|
|
|
|
- def _handle_generate_stream_response(self, model: str,
|
|
|
+ def _handle_generate_stream_response(self, model: str,
|
|
|
credentials: dict,
|
|
|
tools: Optional[list[PromptMessageTool]],
|
|
|
responses: Generator[ChatCompletionChunk, None, None],
|
|
@@ -380,7 +394,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
|
|
continue
|
|
|
-
|
|
|
+
|
|
|
assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
|
|
for tool_call in delta.delta.tool_calls or []:
|
|
|
if tool_call.type == 'function':
|
|
@@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
|
|
|
return message_text
|
|
|
|
|
|
-
|
|
|
- def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
|
|
|
+ def _convert_messages_to_prompt(self, messages: list[PromptMessage],
|
|
|
+ tools: Optional[list[PromptMessageTool]] = None) -> str:
|
|
|
"""
|
|
|
:param messages: List of PromptMessage to combine.
|
|
|
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
|
@@ -473,4 +487,4 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
text += f"\n{tool.json()}"
|
|
|
|
|
|
# trim off the trailing ' ' that might come from the "Assistant: "
|
|
|
- return text.rstrip()
|
|
|
+ return text.rstrip()
|