Parcourir la source

feat: use xinference client instead of xinference (#1339)

takatost il y a 1 an
Parent
commit
3efaa713da

+ 1 - 2
api/core/model_providers/models/embedding/xinference_embedding.py

@@ -1,8 +1,7 @@
-from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
-
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.models.embedding.base import BaseEmbedding
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 
 
 class XinferenceEmbedding(BaseEmbedding):

+ 1 - 1
api/core/model_providers/providers/xinference_provider.py

@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 import requests
-from langchain.embeddings import XinferenceEmbeddings
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType
 

+ 38 - 5
api/core/third_party/langchain/embeddings/xinference_embedding.py

@@ -1,21 +1,54 @@
-from typing import List
+from typing import List, Optional, Any
 
 import numpy as np
-from langchain.embeddings import XinferenceEmbeddings
+from langchain.embeddings.base import Embeddings
+from xinference_client.client.restful.restful_client import Client
 
 
-class XinferenceEmbedding(XinferenceEmbeddings):
+class XinferenceEmbeddings(Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+            self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+
+        super().__init__()
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+
+        self.model_uid = model_uid
+
+        self.client = Client(server_url)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        vectors = super().embed_documents(texts)
+        model = self.client.get_model(self.model_uid)
 
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        vectors = [list(map(float, e)) for e in embeddings]
         normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]
 
         return normalized_vectors
 
     def embed_query(self, text: str) -> List[float]:
-        vector = super().embed_query(text)
+        model = self.client.get_model(self.model_uid)
+
+        embedding_res = model.create_embedding(text)
+
+        embedding = embedding_res["data"][0]["embedding"]
 
+        vector = list(map(float, embedding))
         normalized_vector = (vector / np.linalg.norm(vector)).tolist()
 
         return normalized_vector

+ 42 - 5
api/core/third_party/langchain/llms/xinference_llm.py

@@ -1,16 +1,53 @@
-from typing import Optional, List, Any, Union, Generator
+from typing import Optional, List, Any, Union, Generator, Mapping
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms import Xinference
+from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import (
+from xinference_client.client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
-    RESTfulGenerateModelHandle,
+    RESTfulGenerateModelHandle, Client,
 )
 
 
-class XinferenceLLM(Xinference):
+class XinferenceLLM(LLM):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+            self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__(
+            **{
+                "server_url": server_url,
+                "model_uid": model_uid,
+            }
+        )
+
+        if self.server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if self.model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.client = Client(server_url)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "xinference"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"server_url": self.server_url},
+            **{"model_uid": self.model_uid},
+        }
+
     def _call(
         self,
         prompt: str,

+ 1 - 1
api/requirements.txt

@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.5.2
+xinference-client~=0.1.2
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7