Преглед изворног кода

Hotfix/fix documents index mismatch error in rerank (#1662)

Co-authored-by: baomi.wbm <baomi.wbm@dtwave-inc.com>
WangBooth пре 1 година
родитељ
комит
22bc9ddc73

+ 10 - 5
api/core/model_providers/models/reranking/cohere_reranking.py

@@ -1,14 +1,15 @@
 import logging
-from typing import Optional, List
+from typing import List, Optional
 
 import cohere
 import openai
-from langchain.schema import Document
-
-from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
-    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.error import (LLMAPIConnectionError,
+                                        LLMAPIUnavailableError,
+                                        LLMAuthorizationError,
+                                        LLMBadRequestError, LLMRateLimitError)
 from core.model_providers.models.reranking.base import BaseReranking
 from core.model_providers.providers.base import BaseModelProvider
+from langchain.schema import Document
 
 
 class CohereReranking(BaseReranking):
@@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
     def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
         docs = []
         doc_id = []
+        unique_documents = []
         for document in documents:
             if document.metadata['doc_id'] not in doc_id:
                 doc_id.append(document.metadata['doc_id'])
                 docs.append(document.page_content)
+                unique_documents.append(document)
+        documents = unique_documents
+        
         results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
         rerank_documents = []
 

+ 4 - 1
api/core/model_providers/models/reranking/xinference_reranking.py

@@ -23,11 +23,14 @@ class XinferenceReranking(BaseReranking):
     def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
         docs = []
         doc_id = []
+        unique_documents = []
         for document in documents:
             if document.metadata['doc_id'] not in doc_id:
                 doc_id.append(document.metadata['doc_id'])
                 docs.append(document.page_content)
-
+                unique_documents.append(document)
+        documents = unique_documents
+        
         model = self.client.get_model(self.credentials['model_uid'])
         response = model.rerank(query=query, documents=docs, top_n=top_k)
         rerank_documents = []