|
@@ -1,4 +1,5 @@
|
|
|
import base64
|
|
|
+import json
|
|
|
import mimetypes
|
|
|
from collections.abc import Generator
|
|
|
from typing import Optional, Union, cast
|
|
@@ -15,6 +16,7 @@ from anthropic.types import (
|
|
|
MessageStreamEvent,
|
|
|
completion_create_params,
|
|
|
)
|
|
|
+from anthropic.types.beta.tools import ToolsBetaMessage
|
|
|
from httpx import Timeout
|
|
|
|
|
|
from core.model_runtime.callbacks.base_callback import Callback
|
|
@@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
|
|
|
PromptMessageTool,
|
|
|
SystemPromptMessage,
|
|
|
TextPromptMessageContent,
|
|
|
+ ToolPromptMessage,
|
|
|
UserPromptMessage,
|
|
|
)
|
|
|
from core.model_runtime.errors.invoke import (
|
|
@@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
:return: full response or stream response chunk generator result
|
|
|
"""
|
|
|
|
|
|
- return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
|
|
+ return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
|
|
|
|
|
def _chat_generate(self, model: str, credentials: dict,
|
|
|
- prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
|
|
+ prompt_messages: list[PromptMessage], model_parameters: dict,
|
|
|
+ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
|
|
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
|
|
"""
|
|
|
Invoke llm chat model
|
|
@@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
if system:
|
|
|
extra_model_kwargs['system'] = system
|
|
|
|
|
|
-
|
|
|
- response = client.messages.create(
|
|
|
- model=model,
|
|
|
- messages=prompt_message_dicts,
|
|
|
- stream=stream,
|
|
|
- **model_parameters,
|
|
|
- **extra_model_kwargs
|
|
|
- )
|
|
|
+ if tools:
|
|
|
+ extra_model_kwargs['tools'] = [
|
|
|
+ self._transform_tool_prompt(tool) for tool in tools
|
|
|
+ ]
|
|
|
+ response = client.beta.tools.messages.create(
|
|
|
+ model=model,
|
|
|
+ messages=prompt_message_dicts,
|
|
|
+ stream=stream,
|
|
|
+ **model_parameters,
|
|
|
+ **extra_model_kwargs
|
|
|
+ )
|
|
|
+ else:
|
|
|
+
|
|
|
+ response = client.messages.create(
|
|
|
+ model=model,
|
|
|
+ messages=prompt_message_dicts,
|
|
|
+ stream=stream,
|
|
|
+ **model_parameters,
|
|
|
+ **extra_model_kwargs
|
|
|
+ )
|
|
|
|
|
|
if stream:
|
|
|
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
|
@@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
|
|
|
|
|
def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
    """
    Build the tool definition dict placed in the ``tools`` request argument

    :param tool: prompt message tool whose ``name``, ``description`` and
        ``parameters`` attributes are read
    :return: dict with the keys ``name``, ``description`` and ``input_schema``
    """
    # The tool's ``parameters`` value is forwarded unchanged under the
    # ``input_schema`` key; no validation or copying happens here.
    tool_definition = dict(
        name=tool.name,
        description=tool.description,
        input_schema=tool.parameters,
    )
    return tool_definition
|
|
|
+
|
|
|
def _transform_chat_json_prompts(self, model: str, credentials: dict,
|
|
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
|
|
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
|
@@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
|
|
|
|
|
|
client = Anthropic(api_key="")
|
|
|
- return client.count_tokens(prompt)
|
|
|
+ tokens = client.count_tokens(prompt)
|
|
|
+
|
|
|
+ tool_call_inner_prompts_tokens_map = {
|
|
|
+ 'claude-3-opus-20240229': 395,
|
|
|
+ 'claude-3-haiku-20240307': 264,
|
|
|
+ 'claude-3-sonnet-20240229': 159
|
|
|
+ }
|
|
|
+
|
|
|
+ if model in tool_call_inner_prompts_tokens_map and tools:
|
|
|
+ tokens += tool_call_inner_prompts_tokens_map[model]
|
|
|
+
|
|
|
+ return tokens
|
|
|
|
|
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
|
"""
|
|
@@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
except Exception as ex:
|
|
|
raise CredentialsValidateFailedError(str(ex))
|
|
|
|
|
|
- def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
|
|
|
+ def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
|
|
|
prompt_messages: list[PromptMessage]) -> LLMResult:
|
|
|
"""
|
|
|
Handle llm chat response
|
|
@@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
"""
|
|
|
|
|
|
assistant_prompt_message = AssistantPromptMessage(
|
|
|
- content=response.content[0].text
|
|
|
+ content='',
|
|
|
+ tool_calls=[]
|
|
|
)
|
|
|
|
|
|
+ for content in response.content:
|
|
|
+ if content.type == 'text':
|
|
|
+ assistant_prompt_message.content += content.text
|
|
|
+ elif content.type == 'tool_use':
|
|
|
+ tool_call = AssistantPromptMessage.ToolCall(
|
|
|
+ id=content.id,
|
|
|
+ type='function',
|
|
|
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
+ name=content.name,
|
|
|
+ arguments=json.dumps(content.input)
|
|
|
+ )
|
|
|
+ )
|
|
|
+ assistant_prompt_message.tool_calls.append(tool_call)
|
|
|
+
|
|
|
|
|
|
if response.usage:
|
|
|
|
|
@@ -356,68 +405,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
prompt_message_dicts = []
|
|
|
for message in prompt_messages:
|
|
|
if not isinstance(message, SystemPromptMessage):
|
|
|
- prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
|
|
-
|
|
|
- return system, prompt_message_dicts
|
|
|
-
|
|
|
- def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
|
|
- """
|
|
|
- Convert PromptMessage to dict
|
|
|
- """
|
|
|
- if isinstance(message, UserPromptMessage):
|
|
|
- message = cast(UserPromptMessage, message)
|
|
|
- if isinstance(message.content, str):
|
|
|
- message_dict = {"role": "user", "content": message.content}
|
|
|
- else:
|
|
|
- sub_messages = []
|
|
|
- for message_content in message.content:
|
|
|
- if message_content.type == PromptMessageContentType.TEXT:
|
|
|
- message_content = cast(TextPromptMessageContent, message_content)
|
|
|
- sub_message_dict = {
|
|
|
+ if isinstance(message, UserPromptMessage):
|
|
|
+ message = cast(UserPromptMessage, message)
|
|
|
+ if isinstance(message.content, str):
|
|
|
+ message_dict = {"role": "user", "content": message.content}
|
|
|
+ prompt_message_dicts.append(message_dict)
|
|
|
+ else:
|
|
|
+ sub_messages = []
|
|
|
+ for message_content in message.content:
|
|
|
+ if message_content.type == PromptMessageContentType.TEXT:
|
|
|
+ message_content = cast(TextPromptMessageContent, message_content)
|
|
|
+ sub_message_dict = {
|
|
|
+ "type": "text",
|
|
|
+ "text": message_content.data
|
|
|
+ }
|
|
|
+ sub_messages.append(sub_message_dict)
|
|
|
+ elif message_content.type == PromptMessageContentType.IMAGE:
|
|
|
+ message_content = cast(ImagePromptMessageContent, message_content)
|
|
|
+ if not message_content.data.startswith("data:"):
|
|
|
+
|
|
|
+ try:
|
|
|
+ image_content = requests.get(message_content.data).content
|
|
|
+ mime_type, _ = mimetypes.guess_type(message_content.data)
|
|
|
+ base64_data = base64.b64encode(image_content).decode('utf-8')
|
|
|
+ except Exception as ex:
|
|
|
+ raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
|
|
+ else:
|
|
|
+ data_split = message_content.data.split(";base64,")
|
|
|
+ mime_type = data_split[0].replace("data:", "")
|
|
|
+ base64_data = data_split[1]
|
|
|
+
|
|
|
+ if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
|
|
+ raise ValueError(f"Unsupported image type {mime_type}, "
|
|
|
+ f"only support image/jpeg, image/png, image/gif, and image/webp")
|
|
|
+
|
|
|
+ sub_message_dict = {
|
|
|
+ "type": "image",
|
|
|
+ "source": {
|
|
|
+ "type": "base64",
|
|
|
+ "media_type": mime_type,
|
|
|
+ "data": base64_data
|
|
|
+ }
|
|
|
+ }
|
|
|
+ sub_messages.append(sub_message_dict)
|
|
|
+ prompt_message_dicts.append({"role": "user", "content": sub_messages})
|
|
|
+ elif isinstance(message, AssistantPromptMessage):
|
|
|
+ message = cast(AssistantPromptMessage, message)
|
|
|
+ content = []
|
|
|
+ if message.tool_calls:
|
|
|
+ for tool_call in message.tool_calls:
|
|
|
+ content.append({
|
|
|
+ "type": "tool_use",
|
|
|
+ "id": tool_call.id,
|
|
|
+ "name": tool_call.function.name,
|
|
|
+ "input": json.loads(tool_call.function.arguments)
|
|
|
+ })
|
|
|
+ if message.content:
|
|
|
+ content.append({
|
|
|
"type": "text",
|
|
|
- "text": message_content.data
|
|
|
- }
|
|
|
- sub_messages.append(sub_message_dict)
|
|
|
- elif message_content.type == PromptMessageContentType.IMAGE:
|
|
|
- message_content = cast(ImagePromptMessageContent, message_content)
|
|
|
- if not message_content.data.startswith("data:"):
|
|
|
-
|
|
|
- try:
|
|
|
- image_content = requests.get(message_content.data).content
|
|
|
- mime_type, _ = mimetypes.guess_type(message_content.data)
|
|
|
- base64_data = base64.b64encode(image_content).decode('utf-8')
|
|
|
- except Exception as ex:
|
|
|
- raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
|
|
- else:
|
|
|
- data_split = message_content.data.split(";base64,")
|
|
|
- mime_type = data_split[0].replace("data:", "")
|
|
|
- base64_data = data_split[1]
|
|
|
-
|
|
|
- if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
|
|
|
- raise ValueError(f"Unsupported image type {mime_type}, "
|
|
|
- f"only support image/jpeg, image/png, image/gif, and image/webp")
|
|
|
-
|
|
|
- sub_message_dict = {
|
|
|
- "type": "image",
|
|
|
- "source": {
|
|
|
- "type": "base64",
|
|
|
- "media_type": mime_type,
|
|
|
- "data": base64_data
|
|
|
- }
|
|
|
- }
|
|
|
- sub_messages.append(sub_message_dict)
|
|
|
-
|
|
|
- message_dict = {"role": "user", "content": sub_messages}
|
|
|
- elif isinstance(message, AssistantPromptMessage):
|
|
|
- message = cast(AssistantPromptMessage, message)
|
|
|
- message_dict = {"role": "assistant", "content": message.content}
|
|
|
- elif isinstance(message, SystemPromptMessage):
|
|
|
- message = cast(SystemPromptMessage, message)
|
|
|
- message_dict = {"role": "system", "content": message.content}
|
|
|
- else:
|
|
|
- raise ValueError(f"Got unknown type {message}")
|
|
|
+ "text": message.content
|
|
|
+ })
|
|
|
+
|
|
|
+ if prompt_message_dicts[-1]["role"] == "assistant":
|
|
|
+ prompt_message_dicts[-1]["content"].extend(content)
|
|
|
+ else:
|
|
|
+ prompt_message_dicts.append({
|
|
|
+ "role": "assistant",
|
|
|
+ "content": content
|
|
|
+ })
|
|
|
+ elif isinstance(message, ToolPromptMessage):
|
|
|
+ message = cast(ToolPromptMessage, message)
|
|
|
+ message_dict = {
|
|
|
+ "role": "user",
|
|
|
+ "content": [{
|
|
|
+ "type": "tool_result",
|
|
|
+ "tool_use_id": message.tool_call_id,
|
|
|
+ "content": message.content
|
|
|
+ }]
|
|
|
+ }
|
|
|
+ prompt_message_dicts.append(message_dict)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Got unknown type {message}")
|
|
|
|
|
|
- return message_dict
|
|
|
+ return system, prompt_message_dicts
|
|
|
|
|
|
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
|
|
"""
|
|
@@ -453,6 +523,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|
|
message_text += f"{ai_prompt} [IMAGE]"
|
|
|
elif isinstance(message, SystemPromptMessage):
|
|
|
message_text = content
|
|
|
+ elif isinstance(message, ToolPromptMessage):
|
|
|
+ message_text = f"{human_prompt} {message.content}"
|
|
|
else:
|
|
|
raise ValueError(f"Got unknown type {message}")
|
|
|
|