Parcourir la source

fix: xinference reranker return_documents (#6888)

Weaxs il y a 8 mois
Parent
commit
cc4785f094

+ 2 - 1
api/core/model_runtime/model_providers/xinference/rerank/rerank.py

@@ -57,6 +57,7 @@ class XinferenceRerankModel(RerankModel):
                 documents=docs,
                 query=query,
                 top_n=top_n,
+                return_documents=True
             )
         except RuntimeError as e:
             raise InvokeServerUnavailableError(str(e))
@@ -66,7 +67,7 @@ class XinferenceRerankModel(RerankModel):
         for idx, result in enumerate(response['results']):
             # format document
             index = result['index']
-            page_content = result['document']
+            page_content = result['document'] if isinstance(result['document'], str) else result['document']['text']
             rerank_document = RerankDocument(
                 index=index,
                 text=page_content,

+ 1 - 1
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -106,7 +106,7 @@ class MockXinferenceClass:
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
         
-    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict:
+    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
             self._model_uid != 'rerank':