|
@@ -1,14 +1,15 @@
|
|
|
import logging
|
|
|
-from typing import Optional, List
|
|
|
+from typing import List, Optional
|
|
|
|
|
|
import cohere
|
|
|
import openai
|
|
|
-from langchain.schema import Document
|
|
|
-
|
|
|
-from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
|
|
- LLMRateLimitError, LLMAuthorizationError
|
|
|
+from core.model_providers.error import (LLMAPIConnectionError,
|
|
|
+ LLMAPIUnavailableError,
|
|
|
+ LLMAuthorizationError,
|
|
|
+ LLMBadRequestError, LLMRateLimitError)
|
|
|
from core.model_providers.models.reranking.base import BaseReranking
|
|
|
from core.model_providers.providers.base import BaseModelProvider
|
|
|
+from langchain.schema import Document
|
|
|
|
|
|
|
|
|
class CohereReranking(BaseReranking):
|
|
@@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
|
|
|
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
|
|
docs = []
|
|
|
doc_id = []
|
|
|
+ unique_documents = []
|
|
|
for document in documents:
|
|
|
if document.metadata['doc_id'] not in doc_id:
|
|
|
doc_id.append(document.metadata['doc_id'])
|
|
|
docs.append(document.page_content)
|
|
|
+ unique_documents.append(document)
|
|
|
+ documents = unique_documents
|
|
|
+
|
|
|
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
|
|
rerank_documents = []
|
|
|
|