
Optimize knowledge retrieval performance by batching dataset queries. (#4917)

JasonVV 10 months ago
Commit
7749b71fff

+ 4 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -329,6 +329,7 @@ class DatasetRetrieval:
         """
         if not query:
             return
+        dataset_queries = []
         for dataset_id in dataset_ids:
             dataset_query = DatasetQuery(
                 dataset_id=dataset_id,
@@ -338,7 +339,9 @@ class DatasetRetrieval:
                 created_by_role=user_from,
                 created_by=user_id
             )
-            db.session.add(dataset_query)
+            dataset_queries.append(dataset_query)
+        if dataset_queries:
+            db.session.add_all(dataset_queries)
         db.session.commit()
 
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
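
In the first file, the per-row db.session.add() calls are replaced by collecting the DatasetQuery objects in a list and handing them to db.session.add_all() in one call. A minimal sketch of that pattern follows, using a simplified stand-in DatasetQuery model for illustration rather than Dify's actual class:

# Minimal sketch of the add_all() batching pattern (stand-in model, not Dify's).
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class DatasetQuery(Base):
    __tablename__ = "dataset_queries"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(String, nullable=False)
    content = Column(String, nullable=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def log_dataset_queries(session: Session, dataset_ids: list[str], query: str) -> None:
    if not query:
        return
    # Build every row first, then register them with the session in one call
    # instead of calling session.add() once per dataset.
    dataset_queries = [
        DatasetQuery(dataset_id=dataset_id, content=query)
        for dataset_id in dataset_ids
    ]
    if dataset_queries:
        session.add_all(dataset_queries)
    session.commit()

with Session(engine) as session:
    log_dataset_queries(session, ["ds-1", "ds-2", "ds-3"], "example query")
    print(session.query(DatasetQuery).count())  # prints 3

Broadly speaking, add_all() iterates add() under the hood and the INSERTs are still emitted at flush/commit time, so the gain here is mainly in consolidating the session bookkeeping into a single pass.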

+ 23 - 18
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,5 +1,7 @@
 from typing import Any, cast
 
+from sqlalchemy import func
+
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
@@ -73,30 +75,33 @@ class KnowledgeRetrievalNode(BaseNode):
 
     def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[
         dict[str, Any]]:
-        """
-        A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param node_data: node data
-        :param query: query
-        """
-        tools = []
         available_datasets = []
         dataset_ids = node_data.dataset_ids
-        for dataset_id in dataset_ids:
-            # get dataset from dataset id
-            dataset = db.session.query(Dataset).filter(
-                Dataset.tenant_id == self.tenant_id,
-                Dataset.id == dataset_id
-            ).first()
 
-            # pass if dataset is not available
-            if not dataset:
-                continue
+        # Subquery: Count the number of available documents for each dataset
+        subquery = db.session.query(
+            Document.dataset_id,
+            func.count(Document.id).label('available_document_count')
+        ).filter(
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False,
+            Document.dataset_id.in_(dataset_ids)
+        ).group_by(Document.dataset_id).having(
+            func.count(Document.id) > 0
+        ).subquery()
+
+        results = db.session.query(Dataset).join(
+            subquery, Dataset.id == subquery.c.dataset_id
+        ).filter(
+            Dataset.tenant_id == self.tenant_id,
+            Dataset.id.in_(dataset_ids)
+        ).all()
 
+        for dataset in results:
             # pass if dataset is not available
-            if (dataset and dataset.available_document_count == 0
-                    and dataset.available_document_count == 0):
+            if not dataset:
                 continue
-
             available_datasets.append(dataset)
         all_documents = []
         dataset_retrieval = DatasetRetrieval()
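
In the second file, the old loop issued one SELECT per dataset id and then inspected available_document_count in Python; the new code pushes that filtering into a single aggregated query. A rough sketch of the same subquery-plus-join pattern, with simplified stand-in Dataset and Document models (not Dify's actual definitions):

# Rough sketch of the subquery + join pattern (stand-in models, not Dify's).
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Dataset(Base):
    __tablename__ = "datasets"
    id = Column(String, primary_key=True)
    tenant_id = Column(String, nullable=False)

class Document(Base):
    __tablename__ = "documents"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(String, ForeignKey("datasets.id"), nullable=False)
    indexing_status = Column(String, default="waiting")
    enabled = Column(Boolean, default=True)
    archived = Column(Boolean, default=False)

def fetch_available_datasets(session: Session, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
    # Count completed, enabled, non-archived documents per dataset in one aggregate...
    subquery = (
        session.query(
            Document.dataset_id,
            func.count(Document.id).label("available_document_count"),
        )
        .filter(
            Document.indexing_status == "completed",
            Document.enabled == True,    # noqa: E712 (SQL comparison, not Python identity)
            Document.archived == False,  # noqa: E712
            Document.dataset_id.in_(dataset_ids),
        )
        .group_by(Document.dataset_id)
        .having(func.count(Document.id) > 0)
        .subquery()
    )
    # ...then join Dataset against it so only datasets with at least one usable
    # document are returned, replacing the old one-query-per-dataset loop.
    return (
        session.query(Dataset)
        .join(subquery, Dataset.id == subquery.c.dataset_id)
        .filter(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
        .all()
    )

Since the join already enforces the availability check, the remaining "if not dataset" guard in the committed loop appears to be a defensive no-op: a plain .all() result never contains None entries.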