Browse Source

Fix/new RAG bugs (#2547)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
4be3087642

+ 1 - 1
api/core/indexing_runner.py

@@ -365,7 +365,7 @@ class IndexingRunner:
                 notion_info={
                     "notion_workspace_id": data_source_info['notion_workspace_id'],
                     "notion_obj_id": data_source_info['notion_page_id'],
-                    "notion_page_type": data_source_info['notion_page_type'],
+                    "notion_page_type": data_source_info['type'],
                     "document": dataset_document
                 },
                 document_model=dataset_document.doc_form

+ 6 - 2
api/core/rag/datasource/retrieval_service.py

@@ -2,7 +2,6 @@ import threading
 from typing import Optional
 
 from flask import Flask, current_app
-from flask_login import current_user
 
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -27,6 +26,11 @@ class RetrievalService:
     @classmethod
     def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
                  top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+        dataset = db.session.query(Dataset).filter(
+            Dataset.id == dataset_id
+        ).first()
+        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+            return []
         all_documents = []
         threads = []
         # retrieval_model source with keyword
@@ -73,7 +77,7 @@ class RetrievalService:
             thread.join()
 
         if retrival_method == 'hybrid_search':
-            data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
+            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
             all_documents = data_post_processor.invoke(
                 query=query,
                 documents=all_documents,

+ 1 - 1
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool):
 
             if dataset.indexing_technique == "economy":
                 # use keyword table query
-                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                       dataset_id=dataset.id,
                                                       query=query,
                                                       top_k=self.top_k

+ 1 - 1
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool):
         retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
         if dataset.indexing_technique == "economy":
             # use keyword table query
-            documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+            documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                   dataset_id=dataset.id,
                                                   query=query,
                                                   top_k=self.top_k

+ 0 - 1
api/tasks/clean_dataset_task.py

@@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             indexing_technique=indexing_technique,
             index_struct=index_struct,
             collection_binding_id=collection_binding_id,
-            doc_form=doc_form
         )
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()