Browse Source

feat: claude3 tool call (#3111)

Yeuoly 1 year ago
parent
commit
25b9ac3df4

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml

@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml

@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 1 - 0
api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml

@@ -5,6 +5,7 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 145 - 73
api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -1,4 +1,5 @@
 import base64
+import json
 import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -15,6 +16,7 @@ from anthropic.types import (
     MessageStreamEvent,
     completion_create_params,
 )
+from anthropic.types.beta.tools import ToolsBetaMessage
 from httpx import Timeout
 
 from core.model_runtime.callbacks.base_callback import Callback
@@ -27,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -70,10 +73,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         # invoke model
-        return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, 
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -109,14 +113,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         if system:
             extra_model_kwargs['system'] = system
 
-        # chat model
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        if tools:
+            extra_model_kwargs['tools'] = [
+                self._transform_tool_prompt(tool) for tool in tools
+            ]
+            response = client.beta.tools.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )
+        else:
+            # chat model
+            response = client.messages.create(
+                model=model,
+                messages=prompt_message_dicts,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs
+            )
 
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
@@ -148,6 +164,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
+    def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
+        return {
+            'name': tool.name,
+            'description': tool.description,
+            'input_schema': tool.parameters
+        }
+
     def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                      prompt_messages: list[PromptMessage], model_parameters: dict,
                                      tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@@ -193,7 +216,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
 
         client = Anthropic(api_key="")
-        return client.count_tokens(prompt)
+        tokens = client.count_tokens(prompt)
+
+        tool_call_inner_prompts_tokens_map = {
+            'claude-3-opus-20240229': 395,
+            'claude-3-haiku-20240307': 264,
+            'claude-3-sonnet-20240229': 159
+        }
+
+        if model in tool_call_inner_prompts_tokens_map and tools:
+            tokens += tool_call_inner_prompts_tokens_map[model]
+
+        return tokens
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -219,7 +253,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage],
                                        prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
@@ -232,9 +266,24 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content='',
+            tool_calls=[]
         )
 
+        for content in response.content:
+            if content.type == 'text':
+                assistant_prompt_message.content += content.text
+            elif content.type == 'tool_use':
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=content.id,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=content.name,
+                        arguments=json.dumps(content.input)
+                    )
+                )
+                assistant_prompt_message.tool_calls.append(tool_call)
+
         # calculate num tokens
         if response.usage:
             # transform usage
@@ -356,68 +405,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
-
-        return system, prompt_message_dicts
-
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
-        """
-        Convert PromptMessage to dict
-        """
-        if isinstance(message, UserPromptMessage):
-            message = cast(UserPromptMessage, message)
-            if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
-            else:
-                sub_messages = []
-                for message_content in message.content:
-                    if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
-                        sub_message_dict = {
+                if isinstance(message, UserPromptMessage):
+                    message = cast(UserPromptMessage, message)
+                    if isinstance(message.content, str):
+                        message_dict = {"role": "user", "content": message.content}
+                        prompt_message_dicts.append(message_dict)
+                    else:
+                        sub_messages = []
+                        for message_content in message.content:
+                            if message_content.type == PromptMessageContentType.TEXT:
+                                message_content = cast(TextPromptMessageContent, message_content)
+                                sub_message_dict = {
+                                    "type": "text",
+                                    "text": message_content.data
+                                }
+                                sub_messages.append(sub_message_dict)
+                            elif message_content.type == PromptMessageContentType.IMAGE:
+                                message_content = cast(ImagePromptMessageContent, message_content)
+                                if not message_content.data.startswith("data:"):
+                                    # fetch image data from url
+                                    try:
+                                        image_content = requests.get(message_content.data).content
+                                        mime_type, _ = mimetypes.guess_type(message_content.data)
+                                        base64_data = base64.b64encode(image_content).decode('utf-8')
+                                    except Exception as ex:
+                                        raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                                else:
+                                    data_split = message_content.data.split(";base64,")
+                                    mime_type = data_split[0].replace("data:", "")
+                                    base64_data = data_split[1]
+
+                                if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                                    raise ValueError(f"Unsupported image type {mime_type}, "
+                                                    f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                                sub_message_dict = {
+                                    "type": "image",
+                                    "source": {
+                                        "type": "base64",
+                                        "media_type": mime_type,
+                                        "data": base64_data
+                                    }
+                                }
+                                sub_messages.append(sub_message_dict)
+                        prompt_message_dicts.append({"role": "user", "content": sub_messages})
+                elif isinstance(message, AssistantPromptMessage):
+                    message = cast(AssistantPromptMessage, message)
+                    content = []
+                    if message.tool_calls:
+                        for tool_call in message.tool_calls:
+                            content.append({
+                                "type": "tool_use",
+                                "id": tool_call.id,
+                                "name": tool_call.function.name,
+                                "input": json.loads(tool_call.function.arguments)
+                            })
+                    if message.content:
+                        content.append({
                             "type": "text",
-                            "text": message_content.data
-                        }
-                        sub_messages.append(sub_message_dict)
-                    elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
-                        if not message_content.data.startswith("data:"):
-                            # fetch image data from url
-                            try:
-                                image_content = requests.get(message_content.data).content
-                                mime_type, _ = mimetypes.guess_type(message_content.data)
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
-                            except Exception as ex:
-                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                        else:
-                            data_split = message_content.data.split(";base64,")
-                            mime_type = data_split[0].replace("data:", "")
-                            base64_data = data_split[1]
-
-                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
-                            raise ValueError(f"Unsupported image type {mime_type}, "
-                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
-
-                        sub_message_dict = {
-                            "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
-                            }
-                        }
-                        sub_messages.append(sub_message_dict)
-
-                message_dict = {"role": "user", "content": sub_messages}
-        elif isinstance(message, AssistantPromptMessage):
-            message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
-        elif isinstance(message, SystemPromptMessage):
-            message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
-        else:
-            raise ValueError(f"Got unknown type {message}")
+                            "text": message.content
+                        })
+                    
+                    if prompt_message_dicts[-1]["role"] == "assistant":
+                        prompt_message_dicts[-1]["content"].extend(content)
+                    else:
+                        prompt_message_dicts.append({
+                            "role": "assistant",
+                            "content": content
+                        })
+                elif isinstance(message, ToolPromptMessage):
+                    message = cast(ToolPromptMessage, message)
+                    message_dict = {
+                        "role": "user",
+                        "content": [{
+                            "type": "tool_result",
+                            "tool_use_id": message.tool_call_id,
+                            "content": message.content
+                        }]
+                    }
+                    prompt_message_dicts.append(message_dict)
+                else:
+                    raise ValueError(f"Got unknown type {message}")
 
-        return message_dict
+        return system, prompt_message_dicts
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -453,6 +523,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                         message_text += f"{ai_prompt} [IMAGE]"
         elif isinstance(message, SystemPromptMessage):
             message_text = content
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt} {message.content}"
         else:
             raise ValueError(f"Got unknown type {message}")
 

+ 1 - 1
api/requirements.txt

@@ -36,7 +36,7 @@ python-docx~=1.1.0
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.20.0
+anthropic~=0.23.1
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0