Forráskód Böngészése

compatible xinference reranker server (#6927)

Weaxs 8 hónapja
szülő
commit
5e634a59a2

+ 45 - 8
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -51,17 +51,22 @@ class XinferenceRerankModel(RerankModel):
             server_url = server_url[:-1]
         auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
 
+        params = {
+            'documents': docs,
+            'query': query,
+            'top_n': top_n,
+            'return_documents': True
+        }
         try:
             handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
-            response = handle.rerank(
-                documents=docs,
-                query=query,
-                top_n=top_n,
-                return_documents=True
-            )
+            response = handle.rerank(**params)
         except RuntimeError as e:
-            raise InvokeServerUnavailableError(str(e))
+            if "rerank hasn't support extra parameter" not in str(e):
+                raise InvokeServerUnavailableError(str(e))
 
+            # compatible xinference server between v0.10.1 - v0.12.1, not support 'return_len'
+            handle = RESTfulRerankModelHandleWithoutExtraParameter(model_uid, server_url, auth_headers)
+            response = handle.rerank(**params)
 
         rerank_documents = []
         for idx, result in enumerate(response['results']):
@@ -167,8 +172,40 @@ class XinferenceRerankModel(RerankModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.RERANK,
-            model_properties={ },
+            model_properties={},
             parameter_rules=[]
         )
 
         return entity
+
+
+class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle):
+
+    def rerank(
+            self,
+            documents: list[str],
+            query: str,
+            top_n: Optional[int] = None,
+            max_chunks_per_doc: Optional[int] = None,
+            return_documents: Optional[bool] = None,
+            **kwargs
+    ):
+        url = f"{self._base_url}/v1/rerank"
+        request_body = {
+            "model": self._model_uid,
+            "documents": documents,
+            "query": query,
+            "top_n": top_n,
+            "max_chunks_per_doc": max_chunks_per_doc,
+            "return_documents": return_documents,
+        }
+
+        import requests
+
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise InvokeServerUnavailableError(
+                f"Failed to rerank documents, detail: {response.json()['detail']}"
+            )
+        response_data = response.json()
+        return response_data