|
@@ -1,16 +1,53 @@
|
|
|
-from typing import Optional, List, Any, Union, Generator
|
|
|
+from typing import Optional, List, Any, Union, Generator, Mapping
|
|
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
|
-from langchain.llms import Xinference
|
|
|
+from langchain.llms.base import LLM
|
|
|
from langchain.llms.utils import enforce_stop_tokens
|
|
|
-from xinference.client import (
|
|
|
+from xinference_client.client.restful.restful_client import (
|
|
|
RESTfulChatglmCppChatModelHandle,
|
|
|
RESTfulChatModelHandle,
|
|
|
- RESTfulGenerateModelHandle,
|
|
|
+ RESTfulGenerateModelHandle, Client,
|
|
|
)
|
|
|
|
|
|
|
|
|
-class XinferenceLLM(Xinference):
|
|
|
+class XinferenceLLM(LLM):
|
|
|
+ client: Any
|
|
|
+ server_url: Optional[str]
|
|
|
+ """URL of the xinference server"""
|
|
|
+ model_uid: Optional[str]
|
|
|
+ """UID of the launched model"""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
|
|
+ ):
|
|
|
+ super().__init__(
|
|
|
+ **{
|
|
|
+ "server_url": server_url,
|
|
|
+ "model_uid": model_uid,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ if self.server_url is None:
|
|
|
+ raise ValueError("Please provide server URL")
|
|
|
+
|
|
|
+ if self.model_uid is None:
|
|
|
+ raise ValueError("Please provide the model UID")
|
|
|
+
|
|
|
+ self.client = Client(server_url)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def _llm_type(self) -> str:
|
|
|
+ """Return type of llm."""
|
|
|
+ return "xinference"
|
|
|
+
|
|
|
+ @property
|
|
|
+ def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
+ """Get the identifying parameters."""
|
|
|
+ return {
|
|
|
+ **{"server_url": self.server_url},
|
|
|
+ **{"model_uid": self.model_uid},
|
|
|
+ }
|
|
|
+
|
|
|
def _call(
|
|
|
self,
|
|
|
prompt: str,
|