Ver código fonte

Fix NVIDIA rerank not sorting results by score or applying top_n (#8185)

Jyong 7 meses atrás
pai
commit
2d690801d1

+ 4 - 2
api/core/model_runtime/model_providers/nvidia/rerank/rerank.py

@@ -54,7 +54,6 @@ class NvidiaRerankModel(RerankModel):
                 "query": {"text": query},
                 "passages": [{"text": doc} for doc in docs],
             }
-
             session = requests.Session()
             response = session.post(invoke_url, headers=headers, json=payload)
             response.raise_for_status()
@@ -71,7 +70,10 @@ class NvidiaRerankModel(RerankModel):
                 )
 
                 rerank_documents.append(rerank_document)
-
+            if rerank_documents:
+                rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True)
+                if top_n:
+                    rerank_documents = rerank_documents[:top_n]
             return RerankResult(model=model, docs=rerank_documents)
         except requests.HTTPError as e:
             raise InvokeServerUnavailableError(str(e))