소스 검색

fix: xinference chat support (#939)

takatost 1 년 전
부모
커밋
e0a48c4972

+ 4 - 3
api/core/model_providers/models/llm/xinference_model.py

@@ -1,13 +1,13 @@
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Xinference
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 
 
 class XinferenceModel(BaseLLM):
@@ -16,8 +16,9 @@ class XinferenceModel(BaseLLM):
     def _init_client(self) -> Any:
         self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
 
-        client = Xinference(
-            **self.credentials,
+        client = XinferenceLLM(
+            server_url=self.credentials['server_url'],
+            model_uid=self.credentials['model_uid'],
         )
 
         client.callbacks = self.callbacks

+ 59 - 10
api/core/model_providers/providers/xinference_provider.py

@@ -1,7 +1,8 @@
 import json
 from typing import Type
 
-from langchain.llms import Xinference
+import requests
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
 
 
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
         :param model_type:
         :return:
         """
-        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
-            top_p=KwargRule[float](min=0, max=1, default=0.7),
-            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
-            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=256),
-        )
+        credentials = self.get_model_credentials(model_name, model_type)
+        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        elif credentials['model_format'] == "ggmlv3":
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](min=-2, max=2, default=0),
+                frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
+                max_tokens=KwargRule[int](min=10, max=4000, default=256),
+            )
+        else:
+            return ModelKwargsRules(
+                temperature=KwargRule[float](min=0.01, max=2, default=1),
+                top_p=KwargRule[float](min=0, max=1, default=0.7),
+                presence_penalty=KwargRule[float](enabled=False),
+                frequency_penalty=KwargRule[float](enabled=False),
+                max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            )
+
 
     @classmethod
     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
                 'model_uid': credentials['model_uid'],
             }
 
-            llm = Xinference(
+            llm = XinferenceLLM(
                 **credential_kwargs
             )
 
-            llm("ping", generate_config={'max_tokens': 10})
+            llm("ping")
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
         :param credentials:
         :return:
         """
+        extra_credentials = cls._get_extra_credentials(credentials)
+        credentials.update(extra_credentials)
+
         credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
+
         return credentials
 
     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
 
         return credentials
 
+    @classmethod
+    def _get_extra_credentials(self, credentials: dict) -> dict:
+        url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get the model description, detail: {response.json()['detail']}"
+            )
+        desc = response.json()
+
+        extra_credentials = {
+            'model_format': desc['model_format'],
+        }
+        if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
+            extra_credentials['model_handle_type'] = 'chatglm'
+        elif "generate" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'generate'
+        elif "chat" in desc["model_ability"]:
+            extra_credentials['model_handle_type'] = 'chat'
+        else:
+            raise NotImplementedError(f"Model handle type not supported.")
+
+        return extra_credentials
+
     @classmethod
     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
         return

+ 132 - 0
api/core/third_party/langchain/llms/xinference_llm.py

@@ -0,0 +1,132 @@
+from typing import Optional, List, Any, Union, Generator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms import Xinference
+from langchain.llms.utils import enforce_stop_tokens
+from xinference.client import RESTfulChatglmCppChatModelHandle, \
+    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+
+
+class XinferenceLLM(Xinference):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the xinference model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Returns:
+            The generated string by the model.
+        """
+        model = self.client.get_model(self.model_uid)
+
+        if isinstance(model, RESTfulChatModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulGenerateModelHandle):
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if stop:
+                generate_config["stop"] = stop
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                return combined_text_output
+
+            else:
+                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                return completion["choices"][0]["text"]
+        elif isinstance(model, RESTfulChatglmCppChatModelHandle):
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+
+            if generate_config and generate_config.get("stream"):
+                combined_text_output = ""
+                for token in self._stream_generate(
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
+                ):
+                    combined_text_output += token
+                completion = combined_text_output
+            else:
+                completion = model.chat(prompt=prompt, generate_config=generate_config)
+                completion = completion["choices"][0]["text"]
+
+            if stop is not None:
+                completion = enforce_stop_tokens(completion, stop)
+
+            return completion
+
+
+    def _stream_generate(
+        self,
+        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+    ) -> Generator[str, None, None]:
+        """
+        Args:
+            prompt: The prompt to use for generation.
+            model: The model used for generation.
+            stop: Optional list of stop words to use when generating.
+            generate_config: Optional dictionary for the configuration used for
+                generation.
+
+        Yields:
+            A string token.
+        """
+        if isinstance(model, RESTfulGenerateModelHandle):
+            streaming_response = model.generate(
+                prompt=prompt, generate_config=generate_config
+            )
+        else:
+            streaming_response = model.chat(
+                prompt=prompt, generate_config=generate_config
+            )
+
+        for chunk in streaming_response:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices", [])
+                if choices:
+                    choice = choices[0]
+                    if isinstance(choice, dict):
+                        token = choice.get("text", "")
+                        log_probs = choice.get("logprobs")
+                        if run_manager:
+                            run_manager.on_llm_new_token(
+                                token=token, verbose=self.verbose, log_probs=log_probs
+                            )
+                        yield token

+ 9 - 3
api/tests/unit_tests/model_providers/test_xinference_provider.py

@@ -4,7 +4,6 @@ import json
 
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import CredentialsValidateFailedError
-from core.model_providers.providers.replicate_provider import ReplicateProvider
 from core.model_providers.providers.xinference_provider import XinferenceProvider
 from models.provider import ProviderType, Provider, ProviderModel
 
@@ -25,7 +24,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 
 
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.xinference.Xinference._call',
+    mocker.patch('core.third_party.langchain.llms.xinference_llm.XinferenceLLM._call',
                  return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -53,8 +52,15 @@ def test_is_credentials_valid_or_raise_invalid():
 
 
 @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
-def test_encrypt_model_credentials(mock_encrypt):
+def test_encrypt_model_credentials(mock_encrypt, mocker):
     api_key = 'http://127.0.0.1:9997/'
+
+    mocker.patch('core.model_providers.providers.xinference_provider.XinferenceProvider._get_extra_credentials',
+                 return_value={
+                     'model_handle_type': 'generate',
+                     'model_format': 'ggmlv3'
+                 })
+
     result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
         tenant_id='tenant_id',
         model_name='test_model_name',