|
@@ -3,7 +3,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
|
|
|
- PromptMessageTool, SystemPromptMessage, UserPromptMessage)
|
|
|
+ PromptMessageTool, SystemPromptMessage, UserPromptMessage,
|
|
|
+ TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
|
|
@@ -108,10 +109,21 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
prompt_messages = prompt_messages[1:]
|
|
|
|
|
|
# resolve zhipuai model not support system message and user message, assistant message must be in sequence
|
|
|
- new_prompt_messages = []
|
|
|
+ new_prompt_messages: List[PromptMessage] = []
|
|
|
for prompt_message in prompt_messages:
|
|
|
copy_prompt_message = prompt_message.copy()
|
|
|
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
|
|
|
+ if isinstance(copy_prompt_message.content, list):
|
|
|
+ # check if model is 'glm-4v'
|
|
|
+ if model != 'glm-4v':
|
|
|
+ # not support list message
|
|
|
+ continue
|
|
|
+ # get image and
|
|
|
+ if not isinstance(copy_prompt_message, UserPromptMessage):
|
|
|
+ # not support system message
|
|
|
+ continue
|
|
|
+ new_prompt_messages.append(copy_prompt_message)
|
|
|
+
|
|
|
if not isinstance(copy_prompt_message.content, str):
|
|
|
# not support image message
|
|
|
continue
|
|
@@ -130,14 +142,41 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|
|
else:
|
|
|
new_prompt_messages.append(copy_prompt_message)
|
|
|
|
|
|
- params = {
|
|
|
- 'model': model,
|
|
|
- 'prompt': [{
|
|
|
- 'role': prompt_message.role.value,
|
|
|
- 'content': prompt_message.content
|
|
|
- } for prompt_message in new_prompt_messages],
|
|
|
- **model_parameters
|
|
|
- }
|
|
|
+ if model == 'glm-4v':
|
|
|
+ params = {
|
|
|
+ 'model': model,
|
|
|
+ 'prompt': [{
|
|
|
+ 'role': prompt_message.role.value,
|
|
|
+ 'content':
|
|
|
+ [
|
|
|
+ {
|
|
|
+ 'type': 'text',
|
|
|
+ 'text': prompt_message.content
|
|
|
+ }
|
|
|
+ ] if isinstance(prompt_message.content, str) else
|
|
|
+ [
|
|
|
+ {
|
|
|
+ 'type': 'image',
|
|
|
+ 'image_url': {
|
|
|
+ 'url': content.data
|
|
|
+ }
|
|
|
+ } if content.type == PromptMessageContentType.IMAGE else {
|
|
|
+ 'type': 'text',
|
|
|
+ 'text': content.data
|
|
|
+ } for content in prompt_message.content
|
|
|
+ ],
|
|
|
+ } for prompt_message in new_prompt_messages],
|
|
|
+ **model_parameters
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ params = {
|
|
|
+ 'model': model,
|
|
|
+ 'prompt': [{
|
|
|
+ 'role': prompt_message.role.value,
|
|
|
+ 'content': prompt_message.content,
|
|
|
+ } for prompt_message in new_prompt_messages],
|
|
|
+ **model_parameters
|
|
|
+ }
|
|
|
|
|
|
if stream:
|
|
|
response = client.sse_invoke(incremental=True, **params).events()
|