Ver código fonte

Fix/reranking mode is null (#7012)

Jyong 8 meses atrás
pai
commit
28d4e5b045

+ 1 - 5
api/core/rag/datasource/retrieval_service.py

@@ -28,7 +28,7 @@ class RetrievalService:
     @classmethod
     def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
                  top_k: int, score_threshold: Optional[float] = .0,
-                 reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
+                 reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
                  weights: Optional[dict] = None):
         dataset = db.session.query(Dataset).filter(
             Dataset.id == dataset_id
@@ -36,10 +36,6 @@ class RetrievalService:
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []
         all_documents = []
-        keyword_search_documents = []
-        embedding_search_documents = []
-        full_text_search_documents = []
-        hybrid_search_documents = []
         threads = []
         exceptions = []
         # retrieval_model source with keyword

+ 5 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -278,6 +278,7 @@ class DatasetRetrieval:
                         query=query,
                         top_k=top_k, score_threshold=score_threshold,
                         reranking_model=reranking_model,
+                        reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
                         weights=retrieval_model_config.get('weights', None),
                     )
                 self._on_query(query, [dataset_id], app_id, user_from, user_id)
@@ -431,10 +432,12 @@ class DatasetRetrieval:
                                                           dataset_id=dataset.id,
                                                           query=query,
                                                           top_k=top_k,
-                                                          score_threshold=retrieval_model['score_threshold']
+                                                          score_threshold=retrieval_model.get('score_threshold', .0)
                                                           if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model['reranking_model']
+                                                          reranking_model=retrieval_model.get('reranking_model', None)
                                                           if retrieval_model['reranking_enable'] else None,
+                                                          reranking_mode=retrieval_model.get('reranking_mode')
+                                                          if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                           weights=retrieval_model.get('weights', None),
                                                           )
 

+ 4 - 2
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -177,10 +177,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                                                           dataset_id=dataset.id,
                                                           query=query,
                                                           top_k=self.top_k,
-                                                          score_threshold=retrieval_model['score_threshold']
+                                                          score_threshold=retrieval_model.get('score_threshold', .0)
                                                           if retrieval_model['score_threshold_enabled'] else None,
-                                                          reranking_model=retrieval_model['reranking_model']
+                                                          reranking_model=retrieval_model.get('reranking_model', None)
                                                           if retrieval_model['reranking_enable'] else None,
+                                                          reranking_mode=retrieval_model.get('reranking_mode')
+                                                          if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                           weights=retrieval_model.get('weights', None),
                                                           )
 

+ 6 - 4
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -14,6 +14,7 @@ default_retrieval_model = {
         'reranking_provider_name': '',
         'reranking_model_name': ''
     },
+    'reranking_mode': 'reranking_model',
     'top_k': 2,
     'score_threshold_enabled': False
 }
@@ -71,14 +72,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
         else:
             if self.top_k > 0:
                 # retrieval source
-                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
                                                       dataset_id=dataset.id,
                                                       query=query,
                                                       top_k=self.top_k,
-                                                      score_threshold=retrieval_model['score_threshold']
+                                                      score_threshold=retrieval_model.get('score_threshold', .0)
                                                       if retrieval_model['score_threshold_enabled'] else None,
-                                                      reranking_model=retrieval_model['reranking_model']
-                                                      if retrieval_model['reranking_enable'] else None,
+                                                      reranking_model=retrieval_model.get('reranking_model', None),
+                                                      reranking_mode=retrieval_model.get('reranking_mode')
+                                                      if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                       weights=retrieval_model.get('weights', None),
                                                       )
             else:

+ 4 - 4
api/services/hit_testing_service.py

@@ -42,11 +42,11 @@ class HitTestingService:
                                                   dataset_id=dataset.id,
                                                   query=cls.escape_query_for_search(query),
                                                   top_k=retrieval_model.get('top_k', 2),
-                                                  score_threshold=retrieval_model['score_threshold']
+                                                  score_threshold=retrieval_model.get('score_threshold', .0)
                                                   if retrieval_model['score_threshold_enabled'] else None,
-                                                  reranking_model=retrieval_model['reranking_model']
-                                                  if retrieval_model['reranking_enable'] else None,
-                                                  reranking_mode=retrieval_model.get('reranking_mode', None),
+                                                  reranking_model=retrieval_model.get('reranking_model', None),
+                                                  reranking_mode=retrieval_model.get('reranking_mode')
+                                                  if retrieval_model.get('reranking_mode') else 'reranking_model',
                                                   weights=retrieval_model.get('weights', None),
                                                   )