Bladeren bron

feat: chatglm3 support (#1616)

takatost 1 jaar geleden
bovenliggende
commit
ea526d0822

+ 53 - 14
api/core/model_providers/models/llm/chatglm_model.py

@@ -1,27 +1,45 @@
-import decimal
+import logging
 from typing import List, Optional, Any
 
+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string
 
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 
 
 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return ChatGLM(
+
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        if provider_model_kwargs.get('max_length') is not None:
+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
+            streaming=self.streaming,
             callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
         )
 
+        return client
+
     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
              callbacks: Callbacks = None,
@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs
 
     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True

+ 22 - 19
api/core/model_providers/providers/chatglm_provider.py

@@ -2,6 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+import requests
 from langchain.llms import ChatGLM
 
 from core.helper import encrypter
@@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
+                },
+                {
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
             return []
 
     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
@@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }
 
+        max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
+            max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )
 
     @classmethod
@@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
             raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
 
         try:
-            credential_kwargs = {
-                'endpoint_url': credentials['api_base']
-            }
-
-            llm = ChatGLM(
-                max_token=10,
-                **credential_kwargs
-            )
+            response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
 
-            llm("ping")
+            if response.status_code != 200:
+                raise Exception('ChatGLM Endpoint URL is invalid.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))