|
@@ -51,17 +51,22 @@ class XinferenceRerankModel(RerankModel):
|
|
|
server_url = server_url[:-1]
|
|
|
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
|
|
|
|
|
+ params = {
|
|
|
+ 'documents': docs,
|
|
|
+ 'query': query,
|
|
|
+ 'top_n': top_n,
|
|
|
+ 'return_documents': True
|
|
|
+ }
|
|
|
try:
|
|
|
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
|
|
|
- response = handle.rerank(
|
|
|
- documents=docs,
|
|
|
- query=query,
|
|
|
- top_n=top_n,
|
|
|
- return_documents=True
|
|
|
- )
|
|
|
+ response = handle.rerank(**params)
|
|
|
except RuntimeError as e:
|
|
|
- raise InvokeServerUnavailableError(str(e))
|
|
|
+ if "rerank hasn't support extra parameter" not in str(e):
|
|
|
+ raise InvokeServerUnavailableError(str(e))
|
|
|
|
|
|
+ # compatible xinference server between v0.10.1 - v0.12.1, not support 'return_len'
|
|
|
+ handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
|
|
|
+ response = handle.rerank(**params)
|
|
|
|
|
|
rerank_documents = []
|
|
|
for idx, result in enumerate(response['results']):
|
|
@@ -167,8 +172,40 @@ class XinferenceRerankModel(RerankModel):
|
|
|
),
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
model_type=ModelType.RERANK,
|
|
|
- model_properties={ },
|
|
|
+ model_properties={},
|
|
|
parameter_rules=[]
|
|
|
)
|
|
|
|
|
|
return entity
|
|
|
+
|
|
|
+
|
|
|
class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):
    """Rerank handle that sends only the basic rerank fields to ``/v1/rerank``.

    The request body is built here from a fixed set of keys, so nothing
    beyond ``model``, ``documents``, ``query``, ``top_n``,
    ``max_chunks_per_doc`` and ``return_documents`` is ever sent. It is used
    as a fallback when the server rejects a rerank request with
    "rerank hasn't support extra parameter".
    """

    def rerank(
        self,
        documents: list[str],
        query: str,
        top_n: Optional[int] = None,
        max_chunks_per_doc: Optional[int] = None,
        return_documents: Optional[bool] = None,
        **kwargs
    ):
        """POST a rerank request and return the decoded JSON response.

        :param documents: candidate texts to rerank
        :param query: query the documents are ranked against
        :param top_n: forwarded as ``top_n``; ``None`` is sent as JSON null
        :param max_chunks_per_doc: forwarded as ``max_chunks_per_doc``
        :param return_documents: forwarded as ``return_documents``
        :param kwargs: accepted for signature compatibility and deliberately
            not forwarded to the server
        :return: the parsed JSON body of the response
        :raises InvokeServerUnavailableError: if the server answers with a
            status code other than 200
        """
        # Kept as a function-scope import, as in the original implementation.
        import requests

        url = f"{self._base_url}/v1/rerank"
        request_body = {
            "model": self._model_uid,
            "documents": documents,
            "query": query,
            "top_n": top_n,
            "max_chunks_per_doc": max_chunks_per_doc,
            "return_documents": return_documents,
        }

        response = requests.post(url, json=request_body, headers=self.auth_headers)
        if response.status_code != 200:
            # The error body is not guaranteed to be JSON with a 'detail' key
            # (e.g. an HTML page from a gateway). Fall back to the raw text so
            # the caller still gets InvokeServerUnavailableError rather than a
            # JSON decoding or lookup error that hides the real failure.
            try:
                detail = response.json()['detail']
            except (ValueError, KeyError, TypeError):
                detail = response.text
            raise InvokeServerUnavailableError(
                f"Failed to rerank documents, detail: {detail}"
            )
        return response.json()
|