|
@@ -1,7 +1,8 @@
|
|
|
import json
|
|
|
from typing import Type
|
|
|
|
|
|
-from langchain.llms import Xinference
|
|
|
+import requests
|
|
|
+from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
|
|
|
|
|
from core.helper import encrypter
|
|
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
|
@@ -10,6 +11,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
|
|
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
|
|
|
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
|
+from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
|
|
|
from models.provider import ProviderType
|
|
|
|
|
|
|
|
@@ -48,13 +50,32 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
:param model_type:
|
|
|
:return:
|
|
|
"""
|
|
|
- return ModelKwargsRules(
|
|
|
- temperature=KwargRule[float](min=0, max=2, default=1),
|
|
|
- top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
|
- presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
|
- frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
|
- max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
|
|
- )
|
|
|
+ credentials = self.get_model_credentials(model_name, model_type)
|
|
|
+ if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
|
|
|
+ return ModelKwargsRules(
|
|
|
+ temperature=KwargRule[float](min=0.01, max=2, default=1),
|
|
|
+ top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
|
+ presence_penalty=KwargRule[float](enabled=False),
|
|
|
+ frequency_penalty=KwargRule[float](enabled=False),
|
|
|
+ max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
|
|
+ )
|
|
|
+ elif credentials['model_format'] == "ggmlv3":
|
|
|
+ return ModelKwargsRules(
|
|
|
+ temperature=KwargRule[float](min=0.01, max=2, default=1),
|
|
|
+ top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
|
+ presence_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
|
+ frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
|
|
|
+ max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return ModelKwargsRules(
|
|
|
+ temperature=KwargRule[float](min=0.01, max=2, default=1),
|
|
|
+ top_p=KwargRule[float](min=0, max=1, default=0.7),
|
|
|
+ presence_penalty=KwargRule[float](enabled=False),
|
|
|
+ frequency_penalty=KwargRule[float](enabled=False),
|
|
|
+ max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
@classmethod
|
|
|
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
|
@@ -77,11 +98,11 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
'model_uid': credentials['model_uid'],
|
|
|
}
|
|
|
|
|
|
- llm = Xinference(
|
|
|
+ llm = XinferenceLLM(
|
|
|
**credential_kwargs
|
|
|
)
|
|
|
|
|
|
- llm("ping", generate_config={'max_tokens': 10})
|
|
|
+ llm("ping")
|
|
|
except Exception as ex:
|
|
|
raise CredentialsValidateFailedError(str(ex))
|
|
|
|
|
@@ -97,7 +118,11 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
:param credentials:
|
|
|
:return:
|
|
|
"""
|
|
|
+ extra_credentials = cls._get_extra_credentials(credentials)
|
|
|
+ credentials.update(extra_credentials)
|
|
|
+
|
|
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
|
|
+
|
|
|
return credentials
|
|
|
|
|
|
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
|
@@ -132,6 +157,30 @@ class XinferenceProvider(BaseModelProvider):
|
|
|
|
|
|
return credentials
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def _get_extra_credentials(cls, credentials: dict) -> dict:
|
|
|
+ url = f"{credentials['server_url']}/v1/models/{credentials['model_uid']}"
|
|
|
+ response = requests.get(url)
|
|
|
+ if response.status_code != 200:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Failed to get the model description, detail: {response.json()['detail']}"
|
|
|
+ )
|
|
|
+ desc = response.json()
|
|
|
+
|
|
|
+ extra_credentials = {
|
|
|
+ 'model_format': desc['model_format'],
|
|
|
+ }
|
|
|
+ if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
|
|
|
+ extra_credentials['model_handle_type'] = 'chatglm'
|
|
|
+ elif "generate" in desc["model_ability"]:
|
|
|
+ extra_credentials['model_handle_type'] = 'generate'
|
|
|
+ elif "chat" in desc["model_ability"]:
|
|
|
+ extra_credentials['model_handle_type'] = 'chat'
|
|
|
+ else:
|
|
|
+ raise NotImplementedError("Model handle type not supported.")
|
|
|
+
|
|
|
+ return extra_credentials
|
|
|
+
|
|
|
@classmethod
|
|
|
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
|
|
return
|