Переглянути джерело

fix parent-child retrival count (#15119)

Jyong 1 місяць тому
батько
коміт
435564f0f2

+ 24 - 8
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,9 +1,11 @@
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from extensions.ext_database import db
-from models.dataset import DatasetQuery, DocumentSegment
+from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import DatasetRetrieverResource
 
 
@@ -41,15 +43,29 @@ class DatasetIndexToolCallbackHandler:
         """Handle tool end."""
         for document in documents:
             if document.metadata is not None:
-                query = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.index_node_id == document.metadata["doc_id"]
-                )
+                dataset_document = DatasetDocument.query.filter(
+                    DatasetDocument.id == document.metadata["document_id"]
+                ).first()
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_chunk = ChildChunk.query.filter(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
+                    ).first()
+                    if child_chunk:
+                        segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
+                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                        )
+                else:
+                    query = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.index_node_id == document.metadata["doc_id"]
+                    )
 
-                if "dataset_id" in document.metadata:
-                    query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+                    if "dataset_id" in document.metadata:
+                        query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
 
-                # add hit count to document segment
-                query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+                    # add hit count to document segment
+                    query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
 
                 db.session.commit()
 

+ 25 - 9
api/core/rag/retrieval/dataset_retrieval.py

@@ -21,6 +21,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.context_entities import DocumentContext
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -28,7 +29,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
-from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService
 
@@ -429,16 +430,31 @@ class DatasetRetrieval:
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
             if document.metadata is not None:
-                query = db.session.query(DocumentSegment).filter(
-                    DocumentSegment.index_node_id == document.metadata["doc_id"]
-                )
+                dataset_document = DatasetDocument.query.filter(
+                    DatasetDocument.id == document.metadata["document_id"]
+                ).first()
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_chunk = ChildChunk.query.filter(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
+                    ).first()
+                    if child_chunk:
+                        segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
+                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                        )
+                        db.session.commit()
+                else:
+                    query = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.index_node_id == document.metadata["doc_id"]
+                    )
 
-                # if 'dataset_id' in document.metadata:
-                if "dataset_id" in document.metadata:
-                    query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+                    # if 'dataset_id' in document.metadata:
+                    if "dataset_id" in document.metadata:
+                        query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
 
-                # add hit count to document segment
-                query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+                    # add hit count to document segment
+                    query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
 
                 db.session.commit()