Explorar el Código

fix score_threshold_enabled name (#1626)

Co-authored-by: jyong <jyong@dify.ai>
Jyong hace 1 año
padre
commit
74b2260ba6

+ 4 - 4
api/core/orchestrator_rule_parser.py

@@ -40,7 +40,7 @@ default_retrieval_model = {
         'reranking_model_name': ''
     },
     'top_k': 2,
-    'score_threshold_enable': False
+    'score_threshold_enabled': False
 }
 
 class OrchestratorRuleParser:
@@ -220,8 +220,8 @@ class OrchestratorRuleParser:
                 # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
 
                 score_threshold = None
-                score_threshold_enable = retrieval_model_config.get("score_threshold_enable")
-                if score_threshold_enable:
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
                     score_threshold = retrieval_model_config.get("score_threshold")
 
                 tool = DatasetRetrieverTool.from_dataset(
@@ -239,7 +239,7 @@ class OrchestratorRuleParser:
                 dataset_ids=dataset_ids,
                 tenant_id=kwargs['tenant_id'],
                 top_k=dataset_configs.get('top_k', 2),
-                score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
+                score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None,
                 callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
                 conversation_message_task=conversation_message_task,
                 return_resource=return_resource,

+ 2 - 2
api/core/tool/dataset_multi_retriever_tool.py

@@ -24,7 +24,7 @@ default_retrieval_model = {
         'reranking_model_name': ''
     },
     'top_k': 2,
-    'score_threshold_enable': False
+    'score_threshold_enabled': False
 }
 
 
@@ -216,7 +216,7 @@ class DatasetMultiRetrieverTool(BaseTool):
                                                                       'embeddings': embeddings,
                                                                       'score_threshold': retrieval_model[
                                                                           'score_threshold'] if retrieval_model[
-                                                                          'score_threshold_enable'] else None,
+                                                                          'score_threshold_enabled'] else None,
                                                                       'top_k': self.top_k,
                                                                       'reranking_model': retrieval_model[
                                                                           'reranking_model'] if retrieval_model[

+ 4 - 4
api/core/tool/dataset_retriever_tool.py

@@ -25,7 +25,7 @@ default_retrieval_model = {
         'reranking_model_name': ''
     },
     'top_k': 2,
-    'score_threshold_enable': False
+    'score_threshold_enabled': False
 }
 
 
@@ -110,7 +110,7 @@ class DatasetRetrieverTool(BaseTool):
                         'query': query,
                         'top_k': self.top_k,
                         'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enable'] else None,
+                            'score_threshold_enabled'] else None,
                         'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                             'reranking_enable'] else None,
                         'all_documents': documents,
@@ -129,7 +129,7 @@ class DatasetRetrieverTool(BaseTool):
                         'search_method': retrieval_model['search_method'],
                         'embeddings': embeddings,
                         'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enable'] else None,
+                            'score_threshold_enabled'] else None,
                         'top_k': self.top_k,
                         'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                             'reranking_enable'] else None,
@@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool):
                         model_name=retrieval_model['reranking_model']['reranking_model_name']
                     )
                     documents = hybrid_rerank.rerank(query, documents,
-                                                     retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                                                     retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
                                                      self.top_k)
             else:
                 documents = []

+ 1 - 1
api/fields/dataset_fields.py

@@ -22,7 +22,7 @@ dataset_retrieval_model_fields = {
     'reranking_enable': fields.Boolean,
     'reranking_model': fields.Nested(reranking_model_fields),
     'top_k': fields.Integer,
-    'score_threshold_enable': fields.Boolean,
+    'score_threshold_enabled': fields.Boolean,
     'score_threshold': fields.Float
 }
 

+ 1 - 1
api/models/dataset.py

@@ -104,7 +104,7 @@ class Dataset(db.Model):
                 'reranking_model_name': ''
             },
             'top_k': 2,
-            'score_threshold_enable': False
+            'score_threshold_enabled': False
         }
         return self.retrieval_model if self.retrieval_model else default_retrieval_model
 

+ 2 - 2
api/services/dataset_service.py

@@ -485,7 +485,7 @@ class DocumentService:
                             'reranking_model_name': ''
                         },
                         'top_k': 2,
-                        'score_threshold_enable': False
+                        'score_threshold_enabled': False
                     }
 
                     dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
@@ -769,7 +769,7 @@ class DocumentService:
                         'reranking_model_name': ''
                     },
                     'top_k': 2,
-                    'score_threshold_enable': False
+                    'score_threshold_enabled': False
                 }
                 retrieval_model = default_retrieval_model
         # save dataset

+ 4 - 4
api/services/hit_testing_service.py

@@ -25,7 +25,7 @@ default_retrieval_model = {
         'reranking_model_name': ''
     },
     'top_k': 2,
-    'score_threshold_enable': False
+    'score_threshold_enabled': False
 }
 
 class HitTestingService:
@@ -64,7 +64,7 @@ class HitTestingService:
                 'dataset_id': str(dataset.id),
                 'query': query,
                 'top_k': retrieval_model['top_k'],
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
                 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
                 'all_documents': all_documents,
                 'search_method': retrieval_model['search_method'],
@@ -81,7 +81,7 @@ class HitTestingService:
                 'query': query,
                 'search_method': retrieval_model['search_method'],
                 'embeddings': embeddings,
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
                 'top_k': retrieval_model['top_k'],
                 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
                 'all_documents': all_documents
@@ -99,7 +99,7 @@ class HitTestingService:
                 model_name=retrieval_model['reranking_model']['reranking_model_name']
             )
             all_documents = hybrid_rerank.rerank(query, all_documents,
-                                                 retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                                                 retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
                                                  retrieval_model['top_k'])
 
         end = time.perf_counter()

+ 1 - 1
api/services/retrieval_service.py

@@ -15,7 +15,7 @@ default_retrieval_model = {
         'reranking_model_name': ''
     },
     'top_k': 2,
-    'score_threshold_enable': False
+    'score_threshold_enabled': False
 }