Browse Source

feat: support weixin ernie-bot-4 and chat mode (#1375)

takatost 1 năm trước cách đây
mục cha
commit
7c9b585a47

+ 14 - 5
api/core/model_providers/models/llm/wenxin_model.py

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.wenxin import Wenxin
 
 
 class WenxinModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        # TODO load price_config from configs(db)
         return Wenxin(
             model=self.name,
             streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)
 
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
+
+    @property
+    def support_streaming(self):
+        return True

+ 13 - 5
api/core/model_providers/providers/wenxin_provider.py

@@ -2,6 +2,8 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+from langchain.schema import HumanMessage
+
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
         if model_type == ModelType.TEXT_GENERATION:
             return [
+                {
+                    'id': 'ernie-bot-4',
+                    'name': 'ERNIE-Bot-4',
+                    'mode': ModelMode.CHAT.value,
+                },
                 {
                     'id': 'ernie-bot',
                     'name': 'ERNIE-Bot',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'ernie-bot-turbo',
                     'name': 'ERNIE-Bot-turbo',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'bloomz-7b',
                     'name': 'BLOOMZ-7B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
+            'ernie-bot-4': 4800,
             'ernie-bot': 4800,
             'ernie-bot-turbo': 11200,
         }
 
-        if model_name in ['ernie-bot', 'ernie-bot-turbo']:
+        if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
             return ModelKwargsRules(
                 temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
                 top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
                 **credential_kwargs
             )
 
-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 

+ 6 - 0
api/core/model_providers/rules/wenxin.json

@@ -5,6 +5,12 @@
     "system_config": null,
     "model_flexibility": "fixed",
     "price_config": {
+        "ernie-bot-4": {
+            "prompt": "0",
+            "completion": "0",
+            "unit": "0.001",
+            "currency": "RMB"
+        },
         "ernie-bot": {
             "prompt": "0.012",
             "completion": "0.012",

+ 135 - 63
api/core/third_party/langchain/llms/wenxin.py

@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
 
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-     .. code-block:: python
-         from langchain.llms.wenxin import Wenxin
-         wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-          secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
@@ -228,12 +257,18 @@ class Wenxin(LLM):
                 if token.startswith('data:'):
                     completion = json.loads(token[5:])
 
-                    yield GenerationChunk(text=completion['result'])
-                    if run_manager:
-                        run_manager.on_llm_new_token(completion['result'])
+                    chunk_dict = {
+                        'message': AIMessageChunk(content=completion['result']),
+                    }
 
                     if completion['is_end']:
-                        break
+                        token_usage = completion['usage']
+                        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                        chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                    yield ChatGenerationChunk(**chunk_dict)
+                    if run_manager:
+                        run_manager.on_llm_new_token(completion['result'])
                 else:
                     try:
                         json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                         f" error: {json_response['error_msg']}, "
                         f"please confirm if the model you have chosen is already paid for."
                     )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}

+ 2 - 3
api/tests/integration_tests/models/llm/test_wenxin_model.py

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
 
     model = get_mock_model('ernie-bot')
-    messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
+    messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
     rst = model.run(
-        messages,
-        stop=['\nHuman:'],
+        messages
     )
     assert len(rst.content) > 0

+ 4 - 1
api/tests/unit_tests/model_providers/test_wenxin_provider.py

@@ -2,6 +2,8 @@ import pytest
 from unittest.mock import patch
 import json
 
+from langchain.schema import AIMessage, ChatGeneration, ChatResult
+
 from core.model_providers.providers.base import CredentialsValidateFailedError
 from core.model_providers.providers.wenxin_provider import WenxinProvider
 from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_provider_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
 
     MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)