@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
                                                           TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils import helper
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
-
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion


 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)

         # invoke model
-        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        prompt = self._convert_messages_to_prompt(prompt_messages)
+        prompt = self._convert_messages_to_prompt(prompt_messages, tools)

         return self._get_num_tokens_by_gpt2(prompt)
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 model_parameters={
                     "temperature": 0.5,
                 },
+                tools=[],
                 stream=False
             )
         except Exception as ex:
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

     def _generate(self, model: str, credentials_kwargs: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[List[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs['stop_sequences'] = stop

-        client = ZhipuModelAPI(
+        client = ZhipuAI(
             api_key=credentials_kwargs['api_key']
         )
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     # not support image message
                     continue

-            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
+            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
+                    copy_prompt_message.role == PromptMessageRole.USER:
                 new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
             else:
                 if copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages.append(copy_prompt_message)
+                elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                    new_prompt_messages.append(copy_prompt_message)
+                elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
+                    new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
+                    new_prompt_messages.append(new_prompt_message)
                 else:
                     new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
                     new_prompt_messages.append(new_prompt_message)
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if model == 'glm-4v':
             params = {
                 'model': model,
-                'prompt': [{
+                'messages': [{
                     'role': prompt_message.role.value,
                     'content':
                         [
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         else:
             params = {
                 'model': model,
-                'prompt': [{
-                    'role': prompt_message.role.value,
-                    'content': prompt_message.content,
-                } for prompt_message in new_prompt_messages],
+                'messages': [],
                 **model_parameters
             }
+            # glm-series models (non-chatglm): roles, including tool, pass through as-is
+            if not model.startswith('chatglm'):
+                for prompt_message in new_prompt_messages:
+                    if prompt_message.role == PromptMessageRole.TOOL:
+                        params['messages'].append({
+                            'role': 'tool',
+                            'content': prompt_message.content,
+                            'tool_call_id': prompt_message.tool_call_id
+                        })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+            else:
+                # chatglm models have no system/tool roles
+                for prompt_message in new_prompt_messages:
+                    # merge system and tool messages into the user message
+                    if prompt_message.role == PromptMessageRole.SYSTEM or \
+                            prompt_message.role == PromptMessageRole.TOOL or \
+                            prompt_message.role == PromptMessageRole.USER:
+                        if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
+                            params['messages'][-1]['content'] += "\n\n" + prompt_message.content
+                        else:
+                            params['messages'].append({
+                                'role': 'user',
+                                'content': prompt_message.content
+                            })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+
+        if tools and len(tools) > 0:
+            params['tools'] = [
+                {
+                    'type': 'function',
+                    'function': helper.dump_model(tool)
+                } for tool in tools
+            ]

         if stream:
-            response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
+            response = client.chat.completions.create(stream=stream, **params)
+            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

-        response = client.invoke(**params)
-        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
+        response = client.chat.completions.create(**params)
+        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)

     def _handle_generate_response(self, model: str,
                                   credentials: dict,
-                                  response: Dict[str, Any],
+                                  tools: Optional[list[PromptMessageTool]],
+                                  response: Completion,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        data = response["data"]
         text = ''
-        for res in data["choices"]:
-            text += res['content']
+        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+        for choice in response.choices:
+            if choice.message.tool_calls:
+                for tool_call in choice.message.tool_calls:
+                    if tool_call.type == 'function':
+                        assistant_tool_calls.append(
+                            AssistantPromptMessage.ToolCall(
+                                id=tool_call.id,
+                                type=tool_call.type,
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool_call.function.name,
+                                    arguments=tool_call.function.arguments,
+                                )
+                            )
+                        )
+
+            text += choice.message.content or ''

-        token_usage = data.get("usage")
-        if token_usage is not None:
-            if 'prompt_tokens' not in token_usage:
-                token_usage['prompt_tokens'] = 0
-            if 'completion_tokens' not in token_usage:
-                token_usage['completion_tokens'] = token_usage['total_tokens']
+        prompt_usage = response.usage.prompt_tokens
+        completion_usage = response.usage.completion_tokens

         # transform usage
-        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)

         # transform response
         result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
+            message=AssistantPromptMessage(
+                content=text,
+                tool_calls=assistant_tool_calls
+            ),
             usage=usage,
         )
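
The block above converts the SDK's function-call objects into Dify's
AssistantPromptMessage.ToolCall. As a rough sketch with invented values, a
choice whose message carries a tool call with function.name='get_weather' and
function.arguments='{"city": "Beijing"}' is mapped to:

    AssistantPromptMessage.ToolCall(
        id='call_1',                      # tool_call.id from the SDK response
        type='function',
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
            name='get_weather',
            arguments='{"city": "Beijing"}',
        )
    )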
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

     def _handle_generate_stream_response(self, model: str,
                                          credentials: dict,
-                                         responses: list[Generator],
+                                         tools: Optional[list[PromptMessageTool]],
+                                         responses: Generator[ChatCompletionChunk, None, None],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
         """
-        for index, event in enumerate(responses):
-            if event.event == "add":
+        full_assistant_content = ''
+        for chunk in responses:
+            if len(chunk.choices) == 0:
+                continue
+
+            delta = chunk.choices[0]
+
+            # skip chunks that carry neither content, tool calls, nor a finish reason
+            if delta.finish_reason is None and not delta.delta.tool_calls \
+                    and (delta.delta.content is None or delta.delta.content == ''):
+                continue
+
+            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+            for tool_call in delta.delta.tool_calls or []:
+                if tool_call.type == 'function':
+                    assistant_tool_calls.append(
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call.id,
+                            type=tool_call.type,
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_call.function.name,
+                                arguments=tool_call.function.arguments,
+                            )
+                        )
+                    )
+
+            # transform assistant message to prompt message
+            assistant_prompt_message = AssistantPromptMessage(
+                content=delta.delta.content if delta.delta.content else '',
+                tool_calls=assistant_tool_calls
+            )
+
+            full_assistant_content += delta.delta.content if delta.delta.content else ''
+
+            if delta.finish_reason is not None and chunk.usage is not None:
+                completion_tokens = chunk.usage.completion_tokens
+                prompt_tokens = chunk.usage.prompt_tokens
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
                 yield LLMResultChunk(
+                    model=chunk.model,
                     prompt_messages=prompt_messages,
-                    model=model,
+                    system_fingerprint='',
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data)
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                        finish_reason=delta.finish_reason,
+                        usage=usage
                     )
                 )
-            elif event.event == "error" or event.event == "interrupted":
-                raise ValueError(
-                    f"{event.data}"
-                )
-            elif event.event == "finish":
-                meta = json.loads(event.meta)
-                token_usage = meta['usage']
-                if token_usage is not None:
-                    if 'prompt_tokens' not in token_usage:
-                        token_usage['prompt_tokens'] = 0
-                    if 'completion_tokens' not in token_usage:
-                        token_usage['completion_tokens'] = token_usage['total_tokens']
-
-                    usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
-
+            else:
                 yield LLMResultChunk(
-                    model=model,
+                    model=chunk.model,
                     prompt_messages=prompt_messages,
+                    system_fingerprint='',
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data),
-                        finish_reason='finish',
-                        usage=usage
+                        index=delta.index,
+                        message=assistant_prompt_message,
                     )
                 )
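
For orientation, a minimal sketch of how a caller might drain the generator
yielded above; the loop and variable names are illustrative only, not part of
this change:

    # response is the Generator returned by _generate(..., stream=True)
    collected_text = ''
    collected_tool_calls = []
    for result_chunk in response:
        delta = result_chunk.delta
        collected_text += delta.message.content or ''
        collected_tool_calls.extend(delta.message.tool_calls)
        if delta.finish_reason is not None:
            print(delta.usage)  # usage is only attached to the final chunk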
@@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             raise ValueError(f"Got unknown type {message}")

         return message_text
-
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Anthropic model

+
+    def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+        """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             for message in messages
        )

+        if tools and len(tools) > 0:
+            text += "\n\nTools:"
+            for tool in tools:
+                text += f"\n{tool.json()}"
+
         # trim off the trailing ' ' that might come from the "Assistant: "
-        return text.rstrip()
\ No newline at end of file
+        return text.rstrip()
|