Explorar el Código

add rerank check when doing multi-retrieval (#9998)

Jyong hace 5 meses
padre
commit
9ebd453b87

+ 1 - 1
api/core/rag/rerank/rerank_type.py

@@ -1,6 +1,6 @@
 from enum import Enum
 
 
-class RerankMode(Enum):
+class RerankMode(str, Enum):
     RERANKING_MODEL = "reranking_model"
     WEIGHTED_SCORE = "weighted_score"

+ 31 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.context_entities import DocumentContext
 from core.rag.models.document import Document
+from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@@ -361,10 +362,39 @@ class DatasetRetrieval:
         reranking_enable: bool = True,
         message_id: Optional[str] = None,
     ):
+        if not available_datasets:
+            return []
         threads = []
         all_documents = []
         dataset_ids = [dataset.id for dataset in available_datasets]
-        index_type = None
+        index_type_check = all(
+            item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
+        )
+        if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
+            raise ValueError(
+                "The configured knowledge base list have different indexing technique, please set reranking model."
+            )
+        index_type = available_datasets[0].indexing_technique
+        if index_type == "high_quality":
+            embedding_model_check = all(
+                item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
+            )
+            embedding_model_provider_check = all(
+                item.embedding_model_provider == available_datasets[0].embedding_model_provider
+                for item in available_datasets
+            )
+            if (
+                reranking_enable
+                and reranking_mode == "weighted_score"
+                and (not embedding_model_check or not embedding_model_provider_check)
+            ):
+                raise ValueError(
+                    "The configured knowledge base list have different embedding model, please set reranking model."
+                )
+            if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
+                weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
+                weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
+
         for dataset in available_datasets:
             index_type = dataset.indexing_technique
             retrieval_thread = threading.Thread(