Просмотр исходного кода

Feat/delete single dataset retrival (#6570)

Jyong 9 месяцев назад
Родитель
Сommit
e4bb943fe5

+ 9 - 3
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -62,7 +62,12 @@ class DatasetConfigManager:
             return None
 
         # dataset configs
-        dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
+        if 'dataset_configs' in config and config.get('dataset_configs'):
+            dataset_configs = config.get('dataset_configs')
+        else:
+            dataset_configs = {
+                'retrieval_model': 'multiple'
+            }
         query_variable = config.get('dataset_query_variable')
 
         if dataset_configs['retrieval_model'] == 'single':
@@ -83,9 +88,10 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs['retrieval_model']
                     ),
-                    top_k=dataset_configs.get('top_k'),
+                    top_k=dataset_configs.get('top_k', 4),
                     score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model')
+                    reranking_model=dataset_configs.get('reranking_model'),
+                    weights=dataset_configs.get('weights')
                 )
             )
 

+ 4 - 0
api/core/app/app_config/entities.py

@@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
     retrieve_strategy: RetrieveStrategy
     top_k: Optional[int] = None
     score_threshold: Optional[float] = None
+    rerank_mode: Optional[str] = 'reranking_model'
     reranking_model: Optional[dict] = None
+    weights: Optional[dict] = None
+
+
 
 
 class DatasetEntity(BaseModel):

+ 38 - 15
api/core/rag/data_post_processor/data_post_processor.py

@@ -5,15 +5,20 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.data_post_processor.reorder import ReorderRunner
 from core.rag.models.document import Document
-from core.rag.rerank.rerank import RerankRunner
+from core.rag.rerank.constants.rerank_mode import RerankMode
+from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
+from core.rag.rerank.rerank_model import RerankModelRunner
+from core.rag.rerank.weight_rerank import WeightRerankRunner
 
 
 class DataPostProcessor:
     """Interface for data post-processing document.
     """
 
-    def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
-        self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
+    def __init__(self, tenant_id: str, reranking_mode: str,
+                 reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
+                 reorder_enabled: bool = False):
+        self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
         self.reorder_runner = self._get_reorder_runner(reorder_enabled)
 
     def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
@@ -26,19 +31,37 @@ class DataPostProcessor:
 
         return documents
 
-    def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
-        if reranking_model:
-            try:
-                model_manager = ModelManager()
-                rerank_model_instance = model_manager.get_model_instance(
-                    tenant_id=tenant_id,
-                    provider=reranking_model['reranking_provider_name'],
-                    model_type=ModelType.RERANK,
-                    model=reranking_model['reranking_model_name']
+    def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
+                           weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
+        if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
+            return WeightRerankRunner(
+                tenant_id,
+                Weights(
+                    weight_type=weights['weight_type'],
+                    vector_setting=VectorSetting(
+                        vector_weight=weights['vector_setting']['vector_weight'],
+                        embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
+                        embedding_model_name=weights['vector_setting']['embedding_model_name'],
+                    ),
+                    keyword_setting=KeywordSetting(
+                        keyword_weight=weights['keyword_setting']['keyword_weight'],
+                    )
                 )
-            except InvokeAuthorizationError:
-                return None
-            return RerankRunner(rerank_model_instance)
+            )
+        elif reranking_mode == RerankMode.RERANKING_MODEL.value:
+            if reranking_model:
+                try:
+                    model_manager = ModelManager()
+                    rerank_model_instance = model_manager.get_model_instance(
+                        tenant_id=tenant_id,
+                        provider=reranking_model['reranking_provider_name'],
+                        model_type=ModelType.RERANK,
+                        model=reranking_model['reranking_model_name']
+                    )
+                except InvokeAuthorizationError:
+                    return None
+                return RerankModelRunner(rerank_model_instance)
+            return None
         return None
 
     def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:

+ 2 - 1
api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -1,4 +1,5 @@
 import re
+from typing import Optional
 
 import jieba
 from jieba.analyse import default_tfidf
@@ -11,7 +12,7 @@ class JiebaKeywordTableHandler:
     def __init__(self):
         default_tfidf.stop_words = STOPWORDS
 
-    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
+    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         keywords = jieba.analyse.extract_tags(
             sentence=text,

+ 16 - 4
api/core/rag/datasource/retrieval_service.py

@@ -6,6 +6,7 @@ from flask import Flask, current_app
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.rerank.constants.rerank_mode import RerankMode
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from extensions.ext_database import db
 from models.dataset import Dataset
@@ -26,13 +27,19 @@ class RetrievalService:
 
     @classmethod
     def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
-                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+                 top_k: int, score_threshold: Optional[float] = .0,
+                 reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
+                 weights: Optional[dict] = None):
         dataset = db.session.query(Dataset).filter(
             Dataset.id == dataset_id
         ).first()
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []
         all_documents = []
+        keyword_search_documents = []
+        embedding_search_documents = []
+        full_text_search_documents = []
+        hybrid_search_documents = []
         threads = []
         exceptions = []
         # retrieval_model source with keyword
@@ -87,7 +94,8 @@ class RetrievalService:
             raise Exception(exception_message)
 
         if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
-            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
+                                                    reranking_model, weights, False)
             all_documents = data_post_processor.invoke(
                 query=query,
                 documents=all_documents,
@@ -143,7 +151,9 @@ class RetrievalService:
 
                 if documents:
                     if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
-                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id),
+                                                                RerankMode.RERANKING_MODEL.value,
+                                                                reranking_model, None, False)
                         all_documents.extend(data_post_processor.invoke(
                             query=query,
                             documents=documents,
@@ -175,7 +185,9 @@ class RetrievalService:
                 )
                 if documents:
                     if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
-                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id),
+                                                                RerankMode.RERANKING_MODEL.value,
+                                                                reranking_model, None, False)
                         all_documents.extend(data_post_processor.invoke(
                             query=query,
                             documents=documents,

+ 4 - 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -396,9 +396,11 @@ class QdrantVector(BaseVector):
         documents = []
         for result in results:
             if result:
-                documents.append(self._document_from_scored_point(
+                document = self._document_from_scored_point(
                     result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
-                ))
+                )
+                document.metadata['vector'] = result.vector
+                documents.append(document)
 
         return documents
 

+ 0 - 0
api/core/rag/docstore/__init__.py


+ 8 - 0
api/core/rag/rerank/constants/rerank_mode.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class RerankMode(Enum):
+
+    RERANKING_MODEL = 'reranking_model'
+    WEIGHTED_SCORE = 'weighted_score'
+

+ 23 - 0
api/core/rag/rerank/entity/weight.py

@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+
+
+class VectorSetting(BaseModel):
+    vector_weight: float
+
+    embedding_provider_name: str
+
+    embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+    keyword_weight: float
+
+
+class Weights(BaseModel):
+    """Model for weighted rerank."""
+
+    weight_type: str
+
+    vector_setting: VectorSetting
+
+    keyword_setting: KeywordSetting

+ 1 - 1
api/core/rag/rerank/rerank.py → api/core/rag/rerank/rerank_model.py

@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance
 from core.rag.models.document import Document
 
 
-class RerankRunner:
+class RerankModelRunner:
     def __init__(self, rerank_model_instance: ModelInstance) -> None:
         self.rerank_model_instance = rerank_model_instance
 

+ 178 - 0
api/core/rag/rerank/weight_rerank.py

@@ -0,0 +1,178 @@
+import math
+from collections import Counter
+from typing import Optional
+
+import numpy as np
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.models.document import Document
+from core.rag.rerank.entity.weight import VectorSetting, Weights
+
+
+class WeightRerankRunner:
+
+    def __init__(self, tenant_id: str, weights: Weights) -> None:
+        self.tenant_id = tenant_id
+        self.weights = weights
+
+    def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
+            top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
+        """
+        Run rerank model
+        :param query: search query
+        :param documents: documents for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id if needed
+
+        :return:
+        """
+        docs = []
+        doc_id = []
+        unique_documents = []
+        for document in documents:
+            if document.metadata['doc_id'] not in doc_id:
+                doc_id.append(document.metadata['doc_id'])
+                docs.append(document.page_content)
+                unique_documents.append(document)
+
+        documents = unique_documents
+
+        rerank_documents = []
+        query_scores = self._calculate_keyword_score(query, documents)
+
+        query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
+        for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
+            # format document
+            score = self.weights.vector_setting.vector_weight * query_vector_score + \
+                    self.weights.keyword_setting.keyword_weight * query_score
+            if score_threshold and score < score_threshold:
+                continue
+            document.metadata['score'] = score
+            rerank_documents.append(document)
+        rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
+        return rerank_documents[:top_n] if top_n else rerank_documents
+
+    def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
+        """
+        Calculate BM25 scores
+        :param query: search query
+        :param documents: documents for reranking
+
+        :return:
+        """
+        keyword_table_handler = JiebaKeywordTableHandler()
+        query_keywords = keyword_table_handler.extract_keywords(query, None)
+        documents_keywords = []
+        for document in documents:
+            # get the document keywords
+            document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+            document.metadata['keywords'] = document_keywords
+            documents_keywords.append(document_keywords)
+
+        # Counter query keywords(TF)
+        query_keyword_counts = Counter(query_keywords)
+
+        # total documents
+        total_documents = len(documents)
+
+        # calculate all documents' keywords IDF
+        all_keywords = set()
+        for document_keywords in documents_keywords:
+            all_keywords.update(document_keywords)
+
+        keyword_idf = {}
+        for keyword in all_keywords:
+            # calculate include query keywords' documents
+            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
+            # IDF
+            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
+
+        query_tfidf = {}
+
+        for keyword, count in query_keyword_counts.items():
+            tf = count
+            idf = keyword_idf.get(keyword, 0)
+            query_tfidf[keyword] = tf * idf
+
+        # calculate all documents' TF-IDF
+        documents_tfidf = []
+        for document_keywords in documents_keywords:
+            document_keyword_counts = Counter(document_keywords)
+            document_tfidf = {}
+            for keyword, count in document_keyword_counts.items():
+                tf = count
+                idf = keyword_idf.get(keyword, 0)
+                document_tfidf[keyword] = tf * idf
+            documents_tfidf.append(document_tfidf)
+
+        def cosine_similarity(vec1, vec2):
+            intersection = set(vec1.keys()) & set(vec2.keys())
+            numerator = sum(vec1[x] * vec2[x] for x in intersection)
+
+            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
+            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            denominator = math.sqrt(sum1) * math.sqrt(sum2)
+
+            if not denominator:
+                return 0.0
+            else:
+                return float(numerator) / denominator
+
+        similarities = []
+        for document_tfidf in documents_tfidf:
+            similarity = cosine_similarity(query_tfidf, document_tfidf)
+            similarities.append(similarity)
+
+        # for idx, similarity in enumerate(similarities):
+        #     print(f"Document {idx + 1} similarity: {similarity}")
+
+        return similarities
+
+    def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
+                          vector_setting: VectorSetting) -> list[float]:
+        """
+        Calculate Cosine scores
+        :param query: search query
+        :param documents: documents for reranking
+
+        :return:
+        """
+        query_vector_scores = []
+
+        model_manager = ModelManager()
+
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            provider=vector_setting.embedding_provider_name,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=vector_setting.embedding_model_name
+
+        )
+        cache_embedding = CacheEmbedding(embedding_model)
+        query_vector = cache_embedding.embed_query(query)
+        for document in documents:
+            # calculate cosine similarity
+            if 'score' in document.metadata:
+                query_vector_scores.append(document.metadata['score'])
+            else:
+                content_vector = document.metadata['vector']
+                # transform to NumPy
+                vec1 = np.array(query_vector)
+                vec2 = np.array(document.metadata['vector'])
+
+                # calculate dot product
+                dot_product = np.dot(vec1, vec2)
+
+                # calculate norm
+                norm_vec1 = np.linalg.norm(vec1)
+                norm_vec2 = np.linalg.norm(vec2)
+
+                # calculate cosine similarity
+                cosine_sim = dot_product / (norm_vec1 * norm_vec2)
+                query_vector_scores.append(cosine_sim)
+
+        return query_vector_scores

+ 124 - 22
api/core/rag/retrieval/dataset_retrieval.py

@@ -1,4 +1,6 @@
+import math
 import threading
+from collections import Counter
 from typing import Optional, cast
 
 from flask import Flask, current_app
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
-from core.rag.rerank.rerank import RerankRunner
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@@ -132,8 +135,9 @@ class DatasetRetrieval:
                 app_id, tenant_id, user_id, user_from,
                 available_datasets, query, retrieve_config.top_k,
                 retrieve_config.score_threshold,
-                retrieve_config.reranking_model.get('reranking_provider_name'),
-                retrieve_config.reranking_model.get('reranking_model_name'),
+                retrieve_config.rerank_mode,
+                retrieve_config.reranking_model,
+                retrieve_config.weights,
                 message_id,
             )
 
@@ -272,7 +276,8 @@ class DatasetRetrieval:
                         retrival_method=retrival_method, dataset_id=dataset.id,
                         query=query,
                         top_k=top_k, score_threshold=score_threshold,
-                        reranking_model=reranking_model
+                        reranking_model=reranking_model,
+                        weights=retrieval_model_config.get('weights', None),
                     )
                 self._on_query(query, [dataset_id], app_id, user_from, user_id)
 
@@ -292,14 +297,18 @@ class DatasetRetrieval:
             query: str,
             top_k: int,
             score_threshold: float,
-            reranking_provider_name: str,
-            reranking_model_name: str,
+            reranking_mode: str,
+            reranking_model: Optional[dict] = None,
+            weights: Optional[dict] = None,
+            reranking_enable: bool = True,
             message_id: Optional[str] = None,
     ):
         threads = []
         all_documents = []
         dataset_ids = [dataset.id for dataset in available_datasets]
+        index_type = None
         for dataset in available_datasets:
+            index_type = dataset.indexing_technique
             retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                 'flask_app': current_app._get_current_object(),
                 'dataset_id': dataset.id,
@@ -311,23 +320,24 @@ class DatasetRetrieval:
             retrieval_thread.start()
         for thread in threads:
             thread.join()
-        # do rerank for searched documents
-        model_manager = ModelManager()
-        rerank_model_instance = model_manager.get_model_instance(
-            tenant_id=tenant_id,
-            provider=reranking_provider_name,
-            model_type=ModelType.RERANK,
-            model=reranking_model_name
-        )
 
-        rerank_runner = RerankRunner(rerank_model_instance)
+        if reranking_enable:
+            # do rerank for searched documents
+            data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
+                                                    reranking_model, weights, False)
 
-        with measure_time() as timer:
-            all_documents = rerank_runner.run(
-                query, all_documents,
-                score_threshold,
-                top_k
-            )
+            with measure_time() as timer:
+                all_documents = data_post_processor.invoke(
+                    query=query,
+                    documents=all_documents,
+                    score_threshold=score_threshold,
+                    top_n=top_k
+                )
+        else:
+            if index_type == "economy":
+                all_documents = self.calculate_keyword_score(query, all_documents, top_k)
+            elif index_type == "high_quality":
+                all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
         self._on_query(query, dataset_ids, app_id, user_from, user_id)
 
         if all_documents:
@@ -420,7 +430,8 @@ class DatasetRetrieval:
                                                           score_threshold=retrieval_model['score_threshold']
                                                           if retrieval_model['score_threshold_enabled'] else None,
                                                           reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
+                                                          if retrieval_model['reranking_enable'] else None,
+                                                          weights=retrieval_model.get('weights', None),
                                                           )
 
                     all_documents.extend(documents)
@@ -513,3 +524,94 @@ class DatasetRetrieval:
             tools.append(tool)
 
         return tools
+
+    def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
+        """
+        Calculate keywords scores
+        :param query: search query
+        :param documents: documents for reranking
+
+        :return:
+        """
+        keyword_table_handler = JiebaKeywordTableHandler()
+        query_keywords = keyword_table_handler.extract_keywords(query, None)
+        documents_keywords = []
+        for document in documents:
+            # get the document keywords
+            document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+            document.metadata['keywords'] = document_keywords
+            documents_keywords.append(document_keywords)
+
+        # Counter query keywords(TF)
+        query_keyword_counts = Counter(query_keywords)
+
+        # total documents
+        total_documents = len(documents)
+
+        # calculate all documents' keywords IDF
+        all_keywords = set()
+        for document_keywords in documents_keywords:
+            all_keywords.update(document_keywords)
+
+        keyword_idf = {}
+        for keyword in all_keywords:
+            # calculate include query keywords' documents
+            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
+            # IDF
+            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
+
+        query_tfidf = {}
+
+        for keyword, count in query_keyword_counts.items():
+            tf = count
+            idf = keyword_idf.get(keyword, 0)
+            query_tfidf[keyword] = tf * idf
+
+        # calculate all documents' TF-IDF
+        documents_tfidf = []
+        for document_keywords in documents_keywords:
+            document_keyword_counts = Counter(document_keywords)
+            document_tfidf = {}
+            for keyword, count in document_keyword_counts.items():
+                tf = count
+                idf = keyword_idf.get(keyword, 0)
+                document_tfidf[keyword] = tf * idf
+            documents_tfidf.append(document_tfidf)
+
+        def cosine_similarity(vec1, vec2):
+            intersection = set(vec1.keys()) & set(vec2.keys())
+            numerator = sum(vec1[x] * vec2[x] for x in intersection)
+
+            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
+            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            denominator = math.sqrt(sum1) * math.sqrt(sum2)
+
+            if not denominator:
+                return 0.0
+            else:
+                return float(numerator) / denominator
+
+        similarities = []
+        for document_tfidf in documents_tfidf:
+            similarity = cosine_similarity(query_tfidf, document_tfidf)
+            similarities.append(similarity)
+
+        for document, score in zip(documents, similarities):
+            # format document
+            document.metadata['score'] = score
+        documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
+        return documents[:top_k] if top_k else documents
+
+    def calculate_vector_score(self, all_documents: list[Document],
+                               top_k: int, score_threshold: float) -> list[Document]:
+        filter_documents = []
+        for document in all_documents:
+            if document.metadata['score'] >= score_threshold:
+                filter_documents.append(document)
+        if not filter_documents:
+            return []
+        filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
+        return filter_documents[:top_k] if top_k else filter_documents
+
+
+

+ 0 - 0
api/core/rag/splitter/__init__.py


+ 4 - 3
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
-from core.rag.rerank.rerank import RerankRunner
+from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
@@ -72,7 +72,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             model=self.reranking_model_name
         )
 
-        rerank_runner = RerankRunner(rerank_model_instance)
+        rerank_runner = RerankModelRunner(rerank_model_instance)
         all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
 
         for hit_callback in self.hit_callbacks:
@@ -180,7 +180,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                                                           score_threshold=retrieval_model['score_threshold']
                                                           if retrieval_model['score_threshold_enabled'] else None,
                                                           reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
+                                                          if retrieval_model['reranking_enable'] else None,
+                                                          weights=retrieval_model.get('weights', None),
                                                           )
 
                     all_documents.extend(documents)

+ 2 - 1
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                                                       score_threshold=retrieval_model['score_threshold']
                                                       if retrieval_model['score_threshold_enabled'] else None,
                                                       reranking_model=retrieval_model['reranking_model']
-                                                      if retrieval_model['reranking_enable'] else None
+                                                      if retrieval_model['reranking_enable'] else None,
+                                                      weights=retrieval_model.get('weights', None),
                                                       )
             else:
                 documents = []

+ 28 - 0
api/core/workflow/nodes/knowledge_retrieval/entities.py

@@ -13,13 +13,41 @@ class RerankingModelConfig(BaseModel):
     model: str
 
 
+class VectorSetting(BaseModel):
+    """
+    Vector Setting.
+    """
+    vector_weight: float
+    embedding_provider_name: str
+    embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+    """
+    Keyword Setting.
+    """
+    keyword_weight: float
+
+
+class WeightedScoreConfig(BaseModel):
+    """
+    Weighted score Config.
+    """
+    weight_type: str
+    vector_setting: VectorSetting
+    keyword_setting: KeywordSetting
+
+
 class MultipleRetrievalConfig(BaseModel):
     """
     Multiple Retrieval Config.
     """
     top_k: int
     score_threshold: Optional[float] = None
+    reranking_mode: str = 'reranking_model'
+    reranking_enable: bool = True
     reranking_model: RerankingModelConfig
+    weights: WeightedScoreConfig
 
 
 class ModelConfig(BaseModel):

+ 27 - 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -138,13 +138,38 @@ class KnowledgeRetrievalNode(BaseNode):
                     planning_strategy=planning_strategy
                 )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
+            if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
+                reranking_model = {
+                    'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
+                    'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
+                }
+                weights = None
+            elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
+                reranking_model = None
+                weights = {
+                    'weight_type': node_data.multiple_retrieval_config.weights.weight_type,
+                    'vector_setting': {
+                        "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
+                        "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
+                        "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
+                    },
+                    'keyword_setting': {
+                        "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
+                    }
+                }
+            else:
+                reranking_model = None
+                weights = None
             all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
                                                                 self.user_from.value,
                                                                 available_datasets, query,
                                                                 node_data.multiple_retrieval_config.top_k,
                                                                 node_data.multiple_retrieval_config.score_threshold,
-                                                                node_data.multiple_retrieval_config.reranking_model.provider,
-                                                                node_data.multiple_retrieval_config.reranking_model.model)
+                                                                node_data.multiple_retrieval_config.reranking_mode,
+                                                                reranking_model,
+                                                                weights,
+                                                                node_data.multiple_retrieval_config.reranking_enable,
+                                                                )
 
         context_list = []
         if all_documents:

+ 18 - 0
api/fields/dataset_fields.py

@@ -18,10 +18,28 @@ reranking_model_fields = {
     'reranking_model_name': fields.String
 }
 
+keyword_setting_fields = {
+    'keyword_weight': fields.Float
+}
+
+vector_setting_fields = {
+    'vector_weight': fields.Float,
+    'embedding_model_name': fields.String,
+    'embedding_provider_name': fields.String,
+}
+
+weighted_score_fields = {
+    'weight_type': fields.String,
+    'keyword_setting': fields.Nested(keyword_setting_fields),
+    'vector_setting': fields.Nested(vector_setting_fields),
+}
+
 dataset_retrieval_model_fields = {
     'search_method': fields.String,
     'reranking_enable': fields.Boolean,
+    'reranking_mode': fields.String,
     'reranking_model': fields.Nested(reranking_model_fields),
+    'weights': fields.Nested(weighted_score_fields, allow_null=True),
     'top_k': fields.Integer,
     'score_threshold_enabled': fields.Boolean,
     'score_threshold': fields.Float

+ 3 - 1
api/models/model.py

@@ -328,7 +328,9 @@ class AppModelConfig(db.Model):
                 return {'retrieval_model': 'single'}
             else:
                 return dataset_configs
-        return {'retrieval_model': 'single'}
+        return {
+                'retrieval_model': 'multiple',
+            }
 
     @property
     def file_upload_dict(self) -> dict:

+ 154 - 55
api/poetry.lock

@@ -1242,40 +1242,36 @@ files = [
 
 [[package]]
 name = "chroma-hnswlib"
-version = "0.7.6"
+version = "0.7.3"
 description = "Chromas fork of hnswlib"
 optional = false
 python-versions = "*"
 files = [
-    {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36"},
-    {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82"},
-    {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c"},
-    {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da"},
-    {file = "chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec"},
-    {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca"},
-    {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f"},
-    {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170"},
-    {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9"},
-    {file = "chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3"},
-    {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7"},
-    {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912"},
-    {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4"},
-    {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5"},
-    {file = "chroma_hnswlib-0.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2fe6ea949047beed19a94b33f41fe882a691e58b70c55fdaa90274ae78be046f"},
-    {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feceff971e2a2728c9ddd862a9dd6eb9f638377ad98438876c9aeac96c9482f5"},
-    {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0633b60e00a2b92314d0bf5bbc0da3d3320be72c7e3f4a9b19f4609dc2b2ab"},
-    {file = "chroma_hnswlib-0.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a566abe32fab42291f766d667bdbfa234a7f457dcbd2ba19948b7a978c8ca624"},
-    {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6be47853d9a58dedcfa90fc846af202b071f028bbafe1d8711bf64fe5a7f6111"},
-    {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a7af35bdd39a88bffa49f9bb4bf4f9040b684514a024435a1ef5cdff980579d"},
-    {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53b1f1551f2b5ad94eb610207bde1bb476245fc5097a2bec2b476c653c58bde"},
-    {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3085402958dbdc9ff5626ae58d696948e715aef88c86d1e3f9285a88f1afd3bc"},
-    {file = "chroma_hnswlib-0.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:77326f658a15adfb806a16543f7db7c45f06fd787d699e643642d6bde8ed49c4"},
-    {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:93b056ab4e25adab861dfef21e1d2a2756b18be5bc9c292aa252fa12bb44e6ae"},
-    {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fe91f018b30452c16c811fd6c8ede01f84e5a9f3c23e0758775e57f1c3778871"},
-    {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6c0e627476f0f4d9e153420d36042dd9c6c3671cfd1fe511c0253e38c2a1039"},
-    {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9796a4536b7de6c6d76a792ba03e08f5aaa53e97e052709568e50b4d20c04f"},
-    {file = "chroma_hnswlib-0.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:d30e2db08e7ffdcc415bd072883a322de5995eb6ec28a8f8c054103bbd3ec1e0"},
-    {file = "chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7"},
+    {file = "chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932"},
+    {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca"},
+    {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a"},
+    {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210"},
+    {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403"},
+    {file = "chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029"},
+    {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667"},
+    {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b"},
+    {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec"},
+    {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756"},
+    {file = "chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f"},
+    {file = "chroma_hnswlib-0.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70f897dc6218afa1d99f43a9ad5eb82f392df31f57ff514ccf4eeadecd62f544"},
+    {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aef10b4952708f5a1381c124a29aead0c356f8d7d6e0b520b778aaa62a356f4"},
+    {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee2d8d1529fca3898d512079144ec3e28a81d9c17e15e0ea4665697a7923253"},
+    {file = "chroma_hnswlib-0.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a4021a70e898783cd6f26e00008b494c6249a7babe8774e90ce4766dd288c8ba"},
+    {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a8f61fa1d417fda848e3ba06c07671f14806a2585272b175ba47501b066fe6b1"},
+    {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d7563be58bc98e8f0866907368e22ae218d6060601b79c42f59af4eccbbd2e0a"},
+    {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51b8d411486ee70d7b66ec08cc8b9b6620116b650df9c19076d2d8b6ce2ae914"},
+    {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d706782b628e4f43f1b8a81e9120ac486837fbd9bcb8ced70fe0d9b95c72d77"},
+    {file = "chroma_hnswlib-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:54f053dedc0e3ba657f05fec6e73dd541bc5db5b09aa8bc146466ffb734bdc86"},
+    {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e607c5a71c610a73167a517062d302c0827ccdd6e259af6e4869a5c1306ffb5d"},
+    {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2358a795870156af6761890f9eb5ca8cade57eb10c5f046fe94dae1faa04b9e"},
+    {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cea425df2e6b8a5e201fff0d922a1cc1d165b3cfe762b1408075723c8892218"},
+    {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:454df3dd3e97aa784fba7cf888ad191e0087eef0fd8c70daf28b753b3b591170"},
+    {file = "chroma_hnswlib-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:df587d15007ca701c6de0ee7d5585dd5e976b7edd2b30ac72bc376b3c3f85882"},
 ]
 
 [package.dependencies]
@@ -1283,19 +1279,19 @@ numpy = "*"
 
 [[package]]
 name = "chromadb"
-version = "0.5.5"
+version = "0.5.1"
 description = "Chroma."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "chromadb-0.5.5-py3-none-any.whl", hash = "sha256:2a5a4b84cb0fc32b380e193be68cdbadf3d9f77dbbf141649be9886e42910ddd"},
-    {file = "chromadb-0.5.5.tar.gz", hash = "sha256:84f4bfee320fb4912cbeb4d738f01690891e9894f0ba81f39ee02867102a1c4d"},
+    {file = "chromadb-0.5.1-py3-none-any.whl", hash = "sha256:61f1f75a672b6edce7f1c8875c67e2aaaaf130dc1c1684431fbc42ad7240d01d"},
+    {file = "chromadb-0.5.1.tar.gz", hash = "sha256:e2b2b6a34c2a949bedcaa42fa7775f40c7f6667848fc8094dcbf97fc0d30bee7"},
 ]
 
 [package.dependencies]
 bcrypt = ">=4.0.1"
 build = ">=1.0.3"
-chroma-hnswlib = "0.7.6"
+chroma-hnswlib = "0.7.3"
 fastapi = ">=0.95.2"
 grpcio = ">=1.58.0"
 httpx = ">=0.27.0"
@@ -1314,6 +1310,7 @@ posthog = ">=2.4.0"
 pydantic = ">=1.9"
 pypika = ">=0.48.9"
 PyYAML = ">=6.0.0"
+requests = ">=2.28"
 tenacity = ">=8.2.3"
 tokenizers = ">=0.13.2"
 tqdm = ">=4.65.0"
@@ -6081,19 +6078,6 @@ files = [
     {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
-    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
-    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
-    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
-    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
-    {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
-    {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
-    {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
-    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
-    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
-    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
-    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
-    {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
-    {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
 ]
 
 [package.dependencies]
@@ -6970,6 +6954,23 @@ maintainer = ["zest.releaser[recommended]"]
 pil = ["pillow (>=9.1.0)"]
 test = ["coverage", "pytest"]
 
+[[package]]
+name = "rank-bm25"
+version = "0.2.2"
+description = "Various BM25 algorithms for document ranking"
+optional = false
+python-versions = "*"
+files = [
+    {file = "rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae"},
+    {file = "rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d"},
+]
+
+[package.dependencies]
+numpy = "*"
+
+[package.extras]
+dev = ["pytest"]
+
 [[package]]
 name = "rapidfuzz"
 version = "3.9.4"
@@ -7503,6 +7504,93 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
 testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
 torch = ["safetensors[numpy]", "torch (>=1.10)"]
 
+[[package]]
+name = "scikit-learn"
+version = "1.5.1"
+description = "A set of python modules for machine learning and data mining"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"},
+    {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"},
+    {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"},
+    {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"},
+    {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"},
+    {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"},
+    {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"},
+    {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"},
+    {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"},
+    {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"},
+    {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"},
+    {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"},
+    {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"},
+    {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"},
+    {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"},
+    {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"},
+    {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"},
+    {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"},
+    {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"},
+    {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"},
+    {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.14.0"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.10"
+files = [
+    {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"},
+    {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"},
+    {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"},
+    {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"},
+    {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"},
+    {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"},
+    {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"},
+    {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"},
+    {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"},
+    {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"},
+    {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"},
+    {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"},
+    {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"},
+    {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"},
+    {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"},
+    {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"},
+    {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"},
+    {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"},
+    {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"},
+    {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"},
+    {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"},
+    {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"},
+    {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"},
+    {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"},
+    {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"},
+]
+
+[package.dependencies]
+numpy = ">=1.23.5,<2.3"
+
+[package.extras]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
 [[package]]
 name = "sentry-sdk"
 version = "1.44.1"
@@ -7882,13 +7970,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"]
 
 [[package]]
 name = "tencentcloud-sdk-python-common"
-version = "3.0.1195"
+version = "3.0.1196"
 description = "Tencent Cloud Common SDK for Python"
 optional = false
 python-versions = "*"
 files = [
-    {file = "tencentcloud-sdk-python-common-3.0.1195.tar.gz", hash = "sha256:12acd48a14af327c39edf216bf29c80936f47e351675c4060d695902668ef98d"},
-    {file = "tencentcloud_sdk_python_common-3.0.1195-py2.py3-none-any.whl", hash = "sha256:41c21012176f3d5f4e1ba7fbb2cdcac9c4cd62e57e9447abd0076f36146c75f6"},
+    {file = "tencentcloud-sdk-python-common-3.0.1196.tar.gz", hash = "sha256:a8acd14f7480987ff0fd1d961ad934b2b7533ab1937d7e3adb74d95dc49954bd"},
+    {file = "tencentcloud_sdk_python_common-3.0.1196-py2.py3-none-any.whl", hash = "sha256:5ed438bc3e2818ca8e84b3896aaa2746798fba981bd94b27528eb36efa5b4a30"},
 ]
 
 [package.dependencies]
@@ -7896,17 +7984,28 @@ requests = ">=2.16.0"
 
 [[package]]
 name = "tencentcloud-sdk-python-hunyuan"
-version = "3.0.1195"
+version = "3.0.1196"
 description = "Tencent Cloud Hunyuan SDK for Python"
 optional = false
 python-versions = "*"
 files = [
-    {file = "tencentcloud-sdk-python-hunyuan-3.0.1195.tar.gz", hash = "sha256:c22c21fcef7465eb845b694f7901311db78da45ec1e8ea80ec6549f248cb10a7"},
-    {file = "tencentcloud_sdk_python_hunyuan-3.0.1195-py2.py3-none-any.whl", hash = "sha256:729d19889ebe19258b84f10c950971c07c1be665c72608e475b442dc2b79e0c0"},
+    {file = "tencentcloud-sdk-python-hunyuan-3.0.1196.tar.gz", hash = "sha256:ced26497ae5f1b8fcc6cbd12238109274251e82fa1cfedfd6700df776306a36c"},
+    {file = "tencentcloud_sdk_python_hunyuan-3.0.1196-py2.py3-none-any.whl", hash = "sha256:d18a19cffeaf4ff8a60670dc2bdb644f3d7ae6a51c30d21b50ded24a9c542248"},
 ]
 
 [package.dependencies]
-tencentcloud-sdk-python-common = "3.0.1195"
+tencentcloud-sdk-python-common = "3.0.1196"
+
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+description = "threadpoolctl"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
+    {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
+]
 
 [[package]]
 name = "tidb-vector"
@@ -9444,4 +9543,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "88dace04a79b56b994195cc627be27336083524724b3520e13bd50fe211b32df"
+content-hash = "6b7d8b1333ae9c71ba2e1c5800eecf1535ed3945cd55ebb1e253b7a29ba09559"

+ 3 - 2
api/pyproject.toml

@@ -163,6 +163,7 @@ redis = { version = "~5.0.3", extras = ["hiredis"] }
 replicate = "~0.22.0"
 resend = "~0.7.0"
 safetensors = "~0.4.3"
+scikit-learn = "^1.5.1"
 sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
 sqlalchemy = "~2.0.29"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
@@ -175,7 +176,7 @@ werkzeug = "~3.0.1"
 xinference-client = "0.9.4"
 yarl = "~1.9.4"
 zhipuai = "1.0.7"
-
+rank-bm25 = "~0.2.2"
 ############################################################
 # Tool dependencies required by tool implementations
 ############################################################
@@ -200,7 +201,7 @@ cloudscraper = "1.2.71"
 ############################################################
 
 [tool.poetry.group.vdb.dependencies]
-chromadb = "~0.5.1"
+chromadb = "0.5.1"
 oracledb = "~2.2.1"
 pgvecto-rs = "0.1.4"
 pgvector = "0.2.5"

+ 5 - 3
api/services/hit_testing_service.py

@@ -38,14 +38,16 @@ class HitTestingService:
         if not retrieval_model:
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
-        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
                                                   dataset_id=dataset.id,
                                                   query=cls.escape_query_for_search(query),
-                                                  top_k=retrieval_model['top_k'],
+                                                  top_k=retrieval_model.get('top_k', 2),
                                                   score_threshold=retrieval_model['score_threshold']
                                                   if retrieval_model['score_threshold_enabled'] else None,
                                                   reranking_model=retrieval_model['reranking_model']
-                                                  if retrieval_model['reranking_enable'] else None
+                                                  if retrieval_model['reranking_enable'] else None,
+                                                  reranking_mode=retrieval_model.get('reranking_mode', None),
+                                                  weights=retrieval_model.get('weights', None),
                                                   )
 
         end = time.perf_counter()