
feat: add anthropic claude-2.1 support (#1591)

takatost committed 1 year ago
commit 4a55d5729d

+ 17 - 3
api/core/model_providers/providers/anthropic_provider.py

@@ -32,9 +32,12 @@ class AnthropicProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'claude-instant-1',
-                    'name': 'claude-instant-1',
+                    'id': 'claude-2.1',
+                    'name': 'claude-2.1',
                     'mode': ModelMode.CHAT.value,
+                    'features': [
+                        ModelFeature.AGENT_THOUGHT.value
+                    ]
                 },
                 {
                     'id': 'claude-2',
@@ -44,6 +47,11 @@ class AnthropicProvider(BaseModelProvider):
                         ModelFeature.AGENT_THOUGHT.value
                     ]
                 },
+                {
+                    'id': 'claude-instant-1',
+                    'name': 'claude-instant-1',
+                    'mode': ModelMode.CHAT.value,
+                },
             ]
         else:
             return []
@@ -73,12 +81,18 @@ class AnthropicProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
+        model_max_tokens = {
+            'claude-instant-1': 100000,
+            'claude-2': 100000,
+            'claude-2.1': 200000,
+        }
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
+            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=model_max_tokens.get(model_name, 100000), default=256, precision=0),
         )
 
     @classmethod

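Not part of the commit: a minimal standalone sketch of the per-model max_tokens lookup introduced above, showing that models absent from the dict keep the previous hard-coded 100k ceiling (the helper name below is illustrative, not from the codebase).

# Minimal sketch mirroring the dict added in get_model_parameter_rules.
model_max_tokens = {
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'claude-2.1': 200000,
}

def max_tokens_cap(model_name: str) -> int:
    # .get() falls back to the old fixed limit for unlisted models.
    return model_max_tokens.get(model_name, 100000)

assert max_tokens_cap('claude-2.1') == 200000
assert max_tokens_cap('claude-2') == 100000
assert max_tokens_cap('some-unlisted-model') == 100000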
+ 8 - 2
api/core/model_providers/rules/anthropic.json

@@ -23,8 +23,14 @@
             "currency": "USD"
         },
         "claude-2": {
-            "prompt": "11.02",
-            "completion": "32.68",
+            "prompt": "8.00",
+            "completion": "24.00",
+            "unit": "0.000001",
+            "currency": "USD"
+        },
+        "claude-2.1": {
+            "prompt": "8.00",
+            "completion": "24.00",
             "unit": "0.000001",
             "currency": "USD"
         }

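A hedged sketch of how the updated rule-file prices could translate into a per-request cost, assuming cost = tokens x per-token price x unit (field names and values taken from the JSON above; the exact formula the application applies elsewhere is not part of this commit).

from decimal import Decimal

# Illustrative only: claude-2.1 prices from rules/anthropic.json above.
UNIT = Decimal('0.000001')         # per-token price unit
PROMPT_PRICE = Decimal('8.00')
COMPLETION_PRICE = Decimal('24.00')

def request_cost(prompt_tokens: int, completion_tokens: int) -> Decimal:
    # Assumed formula: tokens * price * unit, summed over prompt and completion.
    return (prompt_tokens * PROMPT_PRICE + completion_tokens * COMPLETION_PRICE) * UNIT

print(request_cost(1000, 500))     # roughly 0.02 USD for 1k prompt + 500 completion tokens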
+ 15 - 3
api/core/third_party/langchain/llms/anthropic_llm.py

@@ -1,7 +1,7 @@
 from typing import Dict
 
-from httpx import Limits
 from langchain.chat_models import ChatAnthropic
+from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
 from langchain.utils import get_from_dict_or_env, check_package_version
 from pydantic import root_validator
 
@@ -29,8 +29,7 @@ class AnthropicLLM(ChatAnthropic):
                 base_url=values["anthropic_api_url"],
                 api_key=values["anthropic_api_key"],
                 timeout=values["default_request_timeout"],
-                max_retries=0,
-                connection_pool_limits=Limits(max_connections=200, max_keepalive_connections=100),
+                max_retries=0
             )
             values["async_client"] = anthropic.AsyncAnthropic(
                 base_url=values["anthropic_api_url"],
@@ -46,3 +45,16 @@ class AnthropicLLM(ChatAnthropic):
                 "Please it install it with `pip install anthropic`."
             )
         return values
+
+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"{message.content}"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text

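Outside the diff, a standalone illustration of what the added _convert_one_message_to_text override produces: it reuses the Anthropic prompt markers ("\n\nHuman:" / "\n\nAssistant:") that ChatAnthropic exposes as HUMAN_PROMPT / AI_PROMPT, and passes system content through verbatim with no role prefix. The helper below is a simplified stand-in, not the class method itself.

# Simplified sketch of the message flattening added above.
HUMAN_PROMPT = "\n\nHuman:"        # value of anthropic.HUMAN_PROMPT
AI_PROMPT = "\n\nAssistant:"       # value of anthropic.AI_PROMPT

def to_text(role: str, content: str) -> str:
    if role == "human":
        return f"{HUMAN_PROMPT} {content}"
    if role == "ai":
        return f"{AI_PROMPT} {content}"
    if role == "system":
        return content             # system text is emitted as-is, no prefix
    return f"\n\n{role.capitalize()}: {content}"   # generic ChatMessage roles

history = [("system", "You are terse."), ("human", "Hi"), ("ai", "Hello.")]
print("".join(to_text(r, c) for r, c in history))
# You are terse.
#
# Human: Hi
#
# Assistant: Hello.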
+ 1 - 1
api/requirements.txt

@@ -35,7 +35,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
-anthropic~=0.3.4
+anthropic~=0.7.2
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0

+ 4 - 4
api/tests/unit_tests/model_providers/test_anthropic_provider.py

@@ -31,12 +31,12 @@ def mock_chat_generate_invalid(messages: List[BaseMessage],
                                run_manager: Optional[CallbackManagerForLLMRun] = None,
                                **kwargs: Any):
     raise anthropic.APIStatusError('Invalid credentials',
-                                   request=httpx._models.Request(
-                                       method='POST',
-                                       url='https://api.anthropic.com/v1/completions',
-                                   ),
                                    response=httpx._models.Response(
                                        status_code=401,
+                                       request=httpx._models.Request(
+                                           method='POST',
+                                           url='https://api.anthropic.com/v1/completions',
+                                       )
                                    ),
                                    body=None
                                 )
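The test change tracks the anthropic>=0.7 exception signature, where the httpx.Request is attached to the Response rather than passed separately. A minimal sketch of raising and catching that error outside the test suite (illustrative only, using the same arguments as the mock above):

import anthropic
import httpx

request = httpx.Request(method='POST', url='https://api.anthropic.com/v1/completions')
response = httpx.Response(status_code=401, request=request)

try:
    raise anthropic.APIStatusError('Invalid credentials', response=response, body=None)
except anthropic.APIStatusError as e:
    # The SDK copies the HTTP status from the response onto the exception.
    assert e.status_code == 401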