浏览代码

feat: optimize minimax llm call (#1312)

takatost 1 年之前
父节点
当前提交
2851a9f04e

+ 10 - 9
api/core/model_providers/models/llm/minimax_model.py

@@ -1,26 +1,23 @@
-import decimal
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Minimax
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
 
 
 class MinimaxModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return Minimax(
+        return MinimaxChatLLM(
             model=self.name,
-            model_kwargs={
-                'stream': False
-            },
+            streaming=self.streaming,
             callbacks=self.callbacks,
             **self.credentials,
             **provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
             return LLMBadRequestError(f"Minimax: {str(ex)}")
         else:
             return ex
+
+    @property
+    def support_streaming(self):
+        return True

+ 4 - 3
api/core/model_providers/providers/minimax_provider.py

@@ -2,7 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
-from langchain.llms import Minimax
+from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
@@ -10,6 +10,7 @@ from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbed
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
 from core.model_providers.models.llm.minimax_model import MinimaxModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
 from models.provider import ProviderType, ProviderQuotaType
 
 
@@ -98,14 +99,14 @@ class MinimaxProvider(BaseModelProvider):
                 'minimax_api_key': credentials['minimax_api_key'],
             }
 
-            llm = Minimax(
+            llm = MinimaxChatLLM(
                 model='abab5.5-chat',
                 max_tokens=10,
                 temperature=0.01,
                 **credential_kwargs
             )
 
-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 

+ 273 - 0
api/core/third_party/langchain/llms/minimax_llm.py

@@ -0,0 +1,273 @@
+import json
+from typing import Dict, Any, Optional, List, Tuple, Iterator
+
+import requests
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import ChatGenerationChunk, ChatGeneration
+from langchain.utils import get_from_dict_or_env
+from pydantic import root_validator, Field, BaseModel
+
+
+class _MinimaxEndpointClient(BaseModel):
+    """An API client that talks to a Minimax llm endpoint."""
+
+    host: str
+    group_id: str
+    api_key: str
+    api_url: str
+
+    @root_validator(pre=True)
+    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "api_url" not in values:
+            host = values["host"]
+            group_id = values["group_id"]
+            api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
+            values["api_url"] = api_url
+        return values
+
+    def post(self, **request: Any) -> Any:
+        stream = 'stream' in request and request['stream']
+
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        response = requests.post(self.api_url, headers=headers, json=request, stream=stream, timeout=(5, 60))
+        if not response.ok:
+            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+
+        if not stream:
+            if response.json()["base_resp"]["status_code"] > 0:
+                raise ValueError(
+                    f"API {response.json()['base_resp']['status_code']}"
+                    f" error: {response.json()['base_resp']['status_msg']}"
+                )
+            return response.json()
+        else:
+            return response
+
+
+class MinimaxChatLLM(BaseChatModel):
+
+    _client: _MinimaxEndpointClient
+    model: str = "abab5.5-chat"
+    """Model name to use."""
+    max_tokens: int = 256
+    """Denotes the number of tokens to predict per generation."""
+    temperature: float = 0.7
+    """A non-negative float that tunes the degree of randomness in generation."""
+    top_p: float = 0.95
+    """Total probability mass of tokens to consider at each step."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    streaming: bool = False
+    """Whether to stream the response or return it all at once."""
+    minimax_api_host: Optional[str] = None
+    minimax_group_id: Optional[str] = None
+    minimax_api_key: Optional[str] = None
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"minimax_api_key": "MINIMAX_API_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["minimax_api_key"] = get_from_dict_or_env(
+            values, "minimax_api_key", "MINIMAX_API_KEY"
+        )
+        values["minimax_group_id"] = get_from_dict_or_env(
+            values, "minimax_group_id", "MINIMAX_GROUP_ID"
+        )
+        # Get custom api url from environment.
+        values["minimax_api_host"] = get_from_dict_or_env(
+            values,
+            "minimax_api_host",
+            "MINIMAX_API_HOST",
+            default="https://api.minimax.chat",
+        )
+        values["_client"] = _MinimaxEndpointClient(
+            host=values["minimax_api_host"],
+            api_key=values["minimax_api_key"],
+            group_id=values["minimax_group_id"],
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            "model": self.model,
+            "tokens_to_generate": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "role_meta": {"user_name": "我", "bot_name": "专家"},
+            **self.model_kwargs,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "minimax"
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, HumanMessage):
+            message_dict = {"sender_type": "USER", "text": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"sender_type": "BOT", "text": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_messages_and_prompt(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        prompt = ""
+        dict_messages = []
+        for m in messages:
+            if isinstance(m, SystemMessage):
+                if prompt:
+                    prompt += "\n"
+                prompt += f"{m.content}"
+                continue
+
+            message = self._convert_message_to_dict(m)
+            dict_messages.append(message)
+
+        prompt = prompt if prompt else ' '
+
+        return dict_messages, prompt
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
+            for chunk in self._stream(
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
+        else:
+            message_dicts, prompt = self._create_messages_and_prompt(messages)
+            params = self._default_params
+            params["messages"] = message_dicts
+            params["prompt"] = prompt
+            params.update(kwargs)
+            response = self._client.post(**params)
+            return self._create_chat_result(response, stop)
+
+    def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, prompt = self._create_messages_and_prompt(messages)
+        params = self._default_params
+        params["messages"] = message_dicts
+        params["prompt"] = prompt
+        params["stream"] = True
+        params.update(kwargs)
+
+        for token in self._client.post(**params).iter_lines():
+            if token:
+                token = token.decode("utf-8")
+
+                if not token.startswith("data:"):
+                    data = json.loads(token)
+                    if "base_resp" in data and data["base_resp"]["status_code"] > 0:
+                        raise ValueError(
+                            f"API {data['base_resp']['status_code']}"
+                            f" error: {data['base_resp']['status_msg']}"
+                        )
+                    else:
+                        continue
+
+                token = token.lstrip("data:").strip()
+                data = json.loads(token)
+                content = data['choices'][0]['delta']
+
+                chunk_kwargs = {
+                    'message': AIMessageChunk(content=content),
+                }
+
+                if 'usage' in data:
+                    token_usage = data['usage']
+                    overall_token_usage = {
+                        'prompt_tokens': 0,
+                        'completion_tokens': token_usage.get('total_tokens', 0),
+                        'total_tokens': token_usage.get('total_tokens', 0)
+                    }
+                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
+
+                yield ChatGenerationChunk(**chunk_kwargs)
+                if run_manager:
+                    run_manager.on_llm_new_token(content)
+
+    def _create_chat_result(self, response: Dict[str, Any], stop: Optional[List[str]] = None) -> ChatResult:
+        text = response['reply']
+        if stop is not None:
+            # This is required since the stop tokens
+            # are not enforced by the model parameters
+            text = enforce_stop_tokens(text, stop)
+
+        generations = [ChatGeneration(message=AIMessage(content=text))]
+        usage = response.get("usage")
+
+        # only return total_tokens in minimax response
+        token_usage = {
+            'prompt_tokens': 0,
+            'completion_tokens': usage.get('total_tokens', 0),
+            'total_tokens': usage.get('total_tokens', 0)
+        }
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+
+        return {"token_usage": token_usage, "model_name": self.model}