Quellcode durchsuchen

Retrieval Service efficiency optimization (#13543)

Charlie.Wei vor 2 Monaten
Ursprung
Commit
222df44d21
2 geänderte Dateien mit 170 neuen und 124 gelöschten Zeilen
  1. 6 0
      api/configs/middleware/__init__.py
  2. 164 124
      api/core/rag/datasource/retrieval_service.py

+ 6 - 0
api/configs/middleware/__init__.py

@@ -1,3 +1,4 @@
+import os
 from typing import Any, Literal, Optional
 from urllib.parse import quote_plus
 
@@ -166,6 +167,11 @@ class DatabaseConfig(BaseSettings):
         default=False,
     )
 
+    RETRIEVAL_SERVICE_WORKER: NonNegativeInt = Field(
+        description="If True, enables the retrieval service worker.",
+        default=os.cpu_count(),
+    )
+
     @computed_field
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         return {

+ 164 - 124
api/core/rag/datasource/retrieval_service.py

@@ -1,9 +1,11 @@
+import concurrent.futures
 import json
-import threading
 from typing import Optional
 
 from flask import Flask, current_app
+from sqlalchemy.orm import load_only
 
+from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -27,6 +29,7 @@ default_retrieval_model = {
 
 
 class RetrievalService:
+    # Cache precompiled regular expressions to avoid repeated compilation
     @classmethod
     def retrieve(
         cls,
@@ -41,74 +44,62 @@ class RetrievalService:
     ):
         if not query:
             return []
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            return []
-
+        dataset = cls._get_dataset(dataset_id)
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []
+
         all_documents: list[Document] = []
-        threads: list[threading.Thread] = []
         exceptions: list[str] = []
-        # retrieval_model source with keyword
-        if retrieval_method == "keyword_search":
-            keyword_thread = threading.Thread(
-                target=RetrievalService.keyword_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(keyword_thread)
-            keyword_thread.start()
-        # retrieval_model source with semantic
-        if RetrievalMethod.is_support_semantic_search(retrieval_method):
-            embedding_thread = threading.Thread(
-                target=RetrievalService.embedding_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "score_threshold": score_threshold,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "retrieval_method": retrieval_method,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-            full_text_index_thread = threading.Thread(
-                target=RetrievalService.full_text_index_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "retrieval_method": retrieval_method,
-                    "score_threshold": score_threshold,
-                    "top_k": top_k,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
 
-        for thread in threads:
-            thread.join()
+        # Optimize multithreading with thread pools
+        with concurrent.futures.ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_WORKER) as executor:  # type: ignore
+            futures = []
+            if retrieval_method == "keyword_search":
+                futures.append(
+                    executor.submit(
+                        cls.keyword_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        all_documents=all_documents,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_semantic_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.embedding_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_fulltext_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.full_text_index_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
 
         if exceptions:
-            exception_message = ";\n".join(exceptions)
-            raise ValueError(exception_message)
+            raise ValueError(";\n".join(exceptions))
 
         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
             data_post_processor = DataPostProcessor(
@@ -133,18 +124,21 @@ class RetrievalService:
         )
         return all_documents
 
+    @classmethod
+    def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
+        return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
     @classmethod
     def keyword_search(
         cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 keyword = Keyword(dataset=dataset)
-
                 documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
                 all_documents.extend(documents)
             except Exception as e:
@@ -165,12 +159,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 vector = Vector(dataset=dataset)
-
                 documents = vector.search_by_vector(
                     query,
                     search_type="similarity_score_threshold",
@@ -187,7 +180,7 @@ class RetrievalService:
                         and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
                     ):
                         data_post_processor = DataPostProcessor(
-                            str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                            str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                         )
                         all_documents.extend(
                             data_post_processor.invoke(
@@ -217,13 +210,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                vector_processor = Vector(
-                    dataset=dataset,
-                )
+                vector_processor = Vector(dataset=dataset)
 
                 documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
                 if documents:
@@ -234,7 +225,7 @@ class RetrievalService:
                         and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
                     ):
                         data_post_processor = DataPostProcessor(
-                            str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                            str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                         )
                         all_documents.extend(
                             data_post_processor.invoke(
@@ -253,64 +244,105 @@ class RetrievalService:
     def escape_query_for_search(query: str) -> str:
         return json.dumps(query).strip('"')
 
-    @staticmethod
-    def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
-        records = []
-        include_segment_ids = []
-        segment_child_map = {}
-        for document in documents:
-            document_id = document.metadata.get("document_id")
-            dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document:
+    @classmethod
+    def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
+        """Format retrieval documents with optimized batch processing"""
+        if not documents:
+            return []
+
+        try:
+            # Collect document IDs
+            document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
+            if not document_ids:
+                return []
+
+            # Batch query dataset documents
+            dataset_documents = {
+                doc.id: doc
+                for doc in db.session.query(DatasetDocument)
+                .filter(DatasetDocument.id.in_(document_ids))
+                .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
+                .all()
+            }
+
+            records = []
+            include_segment_ids = set()
+            segment_child_map = {}
+
+            # Process documents
+            for document in documents:
+                document_id = document.metadata.get("document_id")
+                if document_id not in dataset_documents:
+                    continue
+
+                dataset_document = dataset_documents[document_id]
+
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-                    result = (
-                        db.session.query(ChildChunk, DocumentSegment)
-                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+
+                    child_chunk = (
+                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+                    )
+
+                    if not child_chunk:
+                        continue
+
+                    segment = (
+                        db.session.query(DocumentSegment)
                         .filter(
-                            ChildChunk.index_node_id == child_index_node_id,
                             DocumentSegment.dataset_id == dataset_document.dataset_id,
                             DocumentSegment.enabled == True,
                             DocumentSegment.status == "completed",
+                            DocumentSegment.id == child_chunk.segment_id,
+                        )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                                DocumentSegment.doc_metadata,
+                            )
                         )
                         .first()
                     )
-                    if result:
-                        child_chunk, segment = result
-                        if not segment:
-                            continue
-                        if segment.id not in include_segment_ids:
-                            include_segment_ids.append(segment.id)
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            map_detail = {
-                                "max_score": document.metadata.get("score", 0.0),
-                                "child_chunks": [child_chunk_detail],
-                            }
-                            segment_child_map[segment.id] = map_detail
-                            record = {
-                                "segment": segment,
-                            }
-                            records.append(record)
-                        else:
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                            segment_child_map[segment.id]["max_score"] = max(
-                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                            )
-                    else:
+
+                    if not segment:
                         continue
+
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
+                        segment_child_map[segment.id] = map_detail
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
                 else:
-                    index_node_id = document.metadata["doc_id"]
+                    # Handle normal documents
+                    index_node_id = document.metadata.get("doc_id")
+                    if not index_node_id:
+                        continue
 
                     segment = (
                         db.session.query(DocumentSegment)
@@ -325,16 +357,24 @@ class RetrievalService:
 
                     if not segment:
                         continue
-                    include_segment_ids.append(segment.id)
+
+                    include_segment_ids.add(segment.id)
                     record = {
                         "segment": segment,
-                        "score": document.metadata.get("score", None),
+                        "score": document.metadata.get("score"),  # type: ignore
+                        "segment_metadata": segment.doc_metadata,
                     }
-
                     records.append(record)
+
+            # Add child chunks information to records
             for record in records:
                 if record["segment"].id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]
 
-        return [RetrievalSegments(**record) for record in records]
+            return [RetrievalSegments(**record) for record in records]
+        except Exception as e:
+            db.session.rollback()
+            raise e
+        finally:
+            db.session.close()