
Feat/support parent child chunk (#12092)

Jyong, 4 months ago · parent commit 9231fdbf4c
54 changed files with 2576 additions and 806 deletions
  1. api/controllers/console/datasets/data_source.py (+1 -1)
  2. api/controllers/console/datasets/datasets.py (+14 -1)
  3. api/controllers/console/datasets/datasets_document.py (+99 -79)
  4. api/controllers/console/datasets/datasets_segments.py (+317 -82)
  5. api/controllers/console/datasets/error.py (+12 -0)
  6. api/controllers/service_api/dataset/segment.py (+2 -1)
  7. api/core/entities/knowledge_entities.py (+19 -0)
  8. api/core/indexing_runner.py (+68 -212)
  9. api/core/rag/datasource/retrieval_service.py (+89 -1)
  10. api/core/rag/docstore/dataset_docstore.py (+43 -2)
  11. api/core/rag/embedding/retrieval.py (+23 -0)
  12. api/core/rag/extractor/extract_processor.py (+1 -6)
  13. api/core/rag/extractor/word_extractor.py (+3 -1)
  14. api/core/rag/index_processor/constant/index_type.py (+2 -3)
  15. api/core/rag/index_processor/index_processor_base.py (+14 -11)
  16. api/core/rag/index_processor/index_processor_factory.py (+5 -2)
  17. api/core/rag/index_processor/processor/paragraph_index_processor.py (+23 -6)
  18. api/core/rag/index_processor/processor/parent_child_index_processor.py (+189 -0)
  19. api/core/rag/index_processor/processor/qa_index_processor.py (+41 -22)
  20. api/core/rag/models/document.py (+15 -0)
  21. api/core/rag/retrieval/dataset_retrieval.py (+9 -23)
  22. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+8 -26)
  23. api/fields/dataset_fields.py (+1 -0)
  24. api/fields/document_fields.py (+1 -0)
  25. api/fields/hit_testing_fields.py (+8 -0)
  26. api/fields/segment_fields.py (+14 -0)
  27. api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py (+55 -0)
  28. api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py (+47 -0)
  29. api/models/dataset.py (+84 -3)
  30. api/schedule/clean_unused_datasets_task.py (+35 -1)
  31. api/schedule/mail_clean_document_notify_task.py (+66 -0)
  32. api/services/dataset_service.py (+522 -218)
  33. api/services/entities/knowledge_entities/knowledge_entities.py (+111 -1)
  34. api/services/errors/chunk.py (+9 -0)
  35. api/services/hit_testing_service.py (+5 -32)
  36. api/services/vector_service.py (+171 -23)
  37. api/tasks/add_document_to_index_task.py (+25 -3)
  38. api/tasks/batch_clean_document_task.py (+75 -0)
  39. api/tasks/batch_create_segment_to_index_task.py (+2 -3)
  40. api/tasks/clean_dataset_task.py (+1 -1)
  41. api/tasks/clean_document_task.py (+1 -1)
  42. api/tasks/clean_notion_document_task.py (+1 -1)
  43. api/tasks/deal_dataset_vector_index_task.py (+19 -3)
  44. api/tasks/delete_segment_from_index_task.py (+6 -16)
  45. api/tasks/disable_segments_from_index_task.py (+76 -0)
  46. api/tasks/document_indexing_sync_task.py (+1 -1)
  47. api/tasks/document_indexing_update_task.py (+1 -1)
  48. api/tasks/duplicate_document_indexing_task.py (+3 -3)
  49. api/tasks/enable_segment_to_index_task.py (+18 -1)
  50. api/tasks/enable_segments_to_index_task.py (+108 -0)
  51. api/tasks/remove_document_from_index_task.py (+1 -1)
  52. api/tasks/retry_document_indexing_task.py (+7 -7)
  53. api/tasks/sync_website_document_indexing_task.py (+7 -7)
  54. api/templates/clean_document_job_mail_template-US.html (+98 -0)

+ 1 - 1
api/controllers/console/datasets/data_source.py

@@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource):
             args["doc_form"],
             args["doc_language"],
         )
-        return response, 200
+        return response.model_dump(), 200
 
 
 class DataSourceNotionDatasetSyncApi(Resource):
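
Note on the change above: `response` is now a Pydantic model (typed later in this commit as `IndexingEstimate`) rather than a plain dict, and Flask-RESTful only JSON-encodes plain Python structures, so the resource returns the dumped form. A minimal sketch of the pattern, with an illustrative stand-in model:

from pydantic import BaseModel

class EstimatePayload(BaseModel):  # illustrative stand-in for the real IndexingEstimate model
    total_segments: int
    preview: list[str] = []

payload = EstimatePayload(total_segments=3, preview=["chunk one"])
# Flask-RESTful serializes dicts, not BaseModel instances, hence:
return_value = payload.model_dump()  # {'total_segments': 3, 'preview': ['chunk one']}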

+ 14 - 1
api/controllers/console/datasets/datasets.py

@@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
         except Exception as e:
             raise IndexingEstimateError(str(e))
 
-        return response, 200
+        return response.model_dump(), 200
 
 
 class DatasetRelatedAppListApi(Resource):
@@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
         }, 200
 
 
+class DatasetAutoDisableLogApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
+
+
 api.add_resource(DatasetListApi, "/datasets")
 api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
 api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
 api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
 api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
 api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
+api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
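
The new `DatasetAutoDisableLogApi` exposes the auto-disable log records added by this commit at `/datasets/<dataset_id>/auto-disable-logs` on the console API. A hedged usage sketch (the base URL, auth header, and response shape are assumptions; only the route itself comes from the registration above):

import requests

CONSOLE_API = "http://localhost:5001/console/api"   # assumed local console API base URL
dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder dataset UUID

resp = requests.get(
    f"{CONSOLE_API}/datasets/{dataset_id}/auto-disable-logs",
    headers={"Authorization": "Bearer <console-session-token>"},  # assumed auth scheme
)
resp.raise_for_status()
print(resp.json())  # whatever DatasetService.get_dataset_auto_disable_logs returns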

+ 99 - 79
api/controllers/console/datasets/datasets_document.py

@@ -52,6 +52,7 @@ from fields.document_fields import (
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from services.dataset_service import DatasetService, DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task
 
@@ -255,20 +256,22 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
         parser.add_argument("original_document_id", type=str, required=False, location="json")
         parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+
         parser.add_argument(
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
-        parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
         args = parser.parse_args()
+        knowledge_config = KnowledgeConfig(**args)
 
-        if not dataset.indexing_technique and not args["indexing_technique"]:
+        if not dataset.indexing_technique and not knowledge_config.indexing_technique:
             raise ValueError("indexing_technique is required.")
 
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
-            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+            documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
@@ -278,6 +281,25 @@ class DatasetDocumentListApi(Resource):
 
         return {"documents": documents, "batch": batch}
 
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id):
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
+        try:
+            document_ids = request.args.getlist("document_id")
+            DocumentService.delete_documents(dataset, document_ids)
+        except services.errors.document.DocumentIndexingError:
+            raise DocumentIndexingError("Cannot delete document during indexing.")
+
+        return {"result": "success"}, 204
+
 
 class DatasetInitApi(Resource):
     @setup_required
@@ -313,9 +335,9 @@ class DatasetInitApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
         if not current_user.is_dataset_editor:
             raise Forbidden()
-
-        if args["indexing_technique"] == "high_quality":
-            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+        knowledge_config = KnowledgeConfig(**args)
+        if knowledge_config.indexing_technique == "high_quality":
+            if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                 raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
@@ -334,11 +356,11 @@ class DatasetInitApi(Resource):
                 raise ProviderNotInitializeError(ex.description)
 
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-                tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
+                tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -409,7 +431,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 except Exception as e:
                     raise IndexingEstimateError(str(e))
 
-        return response
+        return response.model_dump(), 200
 
 
 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -422,7 +444,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         documents = self.get_batch_documents(dataset_id, batch)
         response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
         if not documents:
-            return response
+            return response, 200
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
@@ -509,7 +531,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 raise ProviderNotInitializeError(ex.description)
             except Exception as e:
                 raise IndexingEstimateError(str(e))
-        return response
+        return response.model_dump(), 200
 
 
 class DocumentBatchIndexingStatusApi(DocumentResource):
@@ -582,7 +604,8 @@ class DocumentDetailApi(DocumentResource):
         if metadata == "only":
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
         elif metadata == "without":
-            process_rules = DatasetService.get_process_rules(dataset_id)
+            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+            document_process_rules = document.dataset_process_rule.to_dict()
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -590,7 +613,8 @@ class DocumentDetailApi(DocumentResource):
                 "data_source_type": document.data_source_type,
                 "data_source_info": data_source_info,
                 "dataset_process_rule_id": document.dataset_process_rule_id,
-                "dataset_process_rule": process_rules,
+                "dataset_process_rule": dataset_process_rules,
+                "document_process_rule": document_process_rules,
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
@@ -613,7 +637,8 @@ class DocumentDetailApi(DocumentResource):
                 "doc_language": document.doc_language,
             }
         else:
-            process_rules = DatasetService.get_process_rules(dataset_id)
+            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+            document_process_rules = document.dataset_process_rule.to_dict()
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -621,7 +646,8 @@ class DocumentDetailApi(DocumentResource):
                 "data_source_type": document.data_source_type,
                 "data_source_info": data_source_info,
                 "dataset_process_rule_id": document.dataset_process_rule_id,
-                "dataset_process_rule": process_rules,
+                "dataset_process_rule": dataset_process_rules,
+                "document_process_rule": document_process_rules,
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
@@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    def patch(self, dataset_id, document_id, action):
+    def patch(self, dataset_id, action):
         dataset_id = str(dataset_id)
-        document_id = str(document_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -774,84 +799,79 @@ class DocumentStatusApi(DocumentResource):
         # check user's permission
         DatasetService.check_dataset_permission(dataset, current_user)
 
-        document = self.get_document(dataset_id, document_id)
+        document_ids = request.args.getlist("document_id")
+        for document_id in document_ids:
+            document = self.get_document(dataset_id, document_id)
 
-        indexing_cache_key = "document_{}_indexing".format(document.id)
-        cache_result = redis_client.get(indexing_cache_key)
-        if cache_result is not None:
-            raise InvalidActionError("Document is being indexed, please try again later")
+            indexing_cache_key = "document_{}_indexing".format(document.id)
+            cache_result = redis_client.get(indexing_cache_key)
+            if cache_result is not None:
+                raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
 
-        if action == "enable":
-            if document.enabled:
-                raise InvalidActionError("Document already enabled.")
+            if action == "enable":
+                if document.enabled:
+                    continue
+                document.enabled = True
+                document.disabled_at = None
+                document.disabled_by = None
+                document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+                db.session.commit()
 
-            document.enabled = True
-            document.disabled_at = None
-            document.disabled_by = None
-            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-            db.session.commit()
+                # Set cache to prevent indexing the same document multiple times
+                redis_client.setex(indexing_cache_key, 600, 1)
 
-            # Set cache to prevent indexing the same document multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
+                add_document_to_index_task.delay(document_id)
 
-            add_document_to_index_task.delay(document_id)
+            elif action == "disable":
+                if not document.completed_at or document.indexing_status != "completed":
+                    raise InvalidActionError(f"Document: {document.name} is not completed.")
+                if not document.enabled:
+                    continue
 
-            return {"result": "success"}, 200
+                document.enabled = False
+                document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
+                document.disabled_by = current_user.id
+                document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+                db.session.commit()
 
-        elif action == "disable":
-            if not document.completed_at or document.indexing_status != "completed":
-                raise InvalidActionError("Document is not completed.")
-            if not document.enabled:
-                raise InvalidActionError("Document already disabled.")
+                # Set cache to prevent indexing the same document multiple times
+                redis_client.setex(indexing_cache_key, 600, 1)
 
-            document.enabled = False
-            document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
-            document.disabled_by = current_user.id
-            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-            db.session.commit()
+                remove_document_from_index_task.delay(document_id)
 
-            # Set cache to prevent indexing the same document multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
+            elif action == "archive":
+                if document.archived:
+                    continue
 
-            remove_document_from_index_task.delay(document_id)
+                document.archived = True
+                document.archived_at = datetime.now(UTC).replace(tzinfo=None)
+                document.archived_by = current_user.id
+                document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+                db.session.commit()
 
-            return {"result": "success"}, 200
+                if document.enabled:
+                    # Set cache to prevent indexing the same document multiple times
+                    redis_client.setex(indexing_cache_key, 600, 1)
 
-        elif action == "archive":
-            if document.archived:
-                raise InvalidActionError("Document already archived.")
+                    remove_document_from_index_task.delay(document_id)
 
-            document.archived = True
-            document.archived_at = datetime.now(UTC).replace(tzinfo=None)
-            document.archived_by = current_user.id
-            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-            db.session.commit()
+            elif action == "un_archive":
+                if not document.archived:
+                    continue
+                document.archived = False
+                document.archived_at = None
+                document.archived_by = None
+                document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+                db.session.commit()
 
-            if document.enabled:
                 # Set cache to prevent indexing the same document multiple times
                 redis_client.setex(indexing_cache_key, 600, 1)
 
-                remove_document_from_index_task.delay(document_id)
-
-            return {"result": "success"}, 200
-        elif action == "un_archive":
-            if not document.archived:
-                raise InvalidActionError("Document is not archived.")
-
-            document.archived = False
-            document.archived_at = None
-            document.archived_by = None
-            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-            db.session.commit()
-
-            # Set cache to prevent indexing the same document multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
+                add_document_to_index_task.delay(document_id)
 
-            add_document_to_index_task.delay(document_id)
-
-            return {"result": "success"}, 200
-        else:
-            raise InvalidActionError()
+            else:
+                raise InvalidActionError()
+        return {"result": "success"}, 200
 
 
 class DocumentPauseApi(DocumentResource):
@@ -1022,7 +1042,7 @@ api.add_resource(
 )
 api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
 api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
-api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
 api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
 api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
 api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
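
The document status endpoint now operates on batches: the document id moved out of the URL into a repeatable `document_id` query parameter (read with `request.args.getlist`), and the action is applied to every id in the list. A hedged sketch of calling it (base URL and auth are assumptions; the route and query parameter name come from the diff above):

import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API base URL
dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder dataset UUID

# Disable two documents in one call; repeated query keys become getlist() entries server-side.
resp = requests.patch(
    f"{CONSOLE_API}/datasets/{dataset_id}/documents/status/disable/batch",
    params=[("document_id", "doc-id-1"), ("document_id", "doc-id-2")],
    headers={"Authorization": "Bearer <console-session-token>"},  # assumed auth scheme
)
print(resp.status_code, resp.json())  # {'result': 'success'} on success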

+ 317 - 82
api/controllers/console/datasets/datasets_segments.py

@@ -1,5 +1,4 @@
 import uuid
-from datetime import UTC, datetime
 
 import pandas as pd
 from flask import request
@@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
 import services
 from controllers.console import api
 from controllers.console.app.error import ProviderNotInitializeError
-from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
+from controllers.console.datasets.error import (
+    ChildChunkDeleteIndexError,
+    ChildChunkIndexingError,
+    InvalidActionError,
+    NoFileUploadedError,
+    TooManyFilesError,
+)
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_knowledge_limit_check,
@@ -20,15 +25,15 @@ from controllers.console.wraps import (
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from fields.segment_fields import segment_fields
+from fields.segment_fields import child_chunk_fields, segment_fields
 from libs.login import login_required
-from models import DocumentSegment
+from models.dataset import ChildChunk, DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
+from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
+from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
-from tasks.disable_segment_from_index_task import disable_segment_from_index_task
-from tasks.enable_segment_to_index_task import enable_segment_to_index_task
 
 
 class DatasetDocumentSegmentListApi(Resource):
@@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
             raise NotFound("Document not found.")
 
         parser = reqparse.RequestParser()
-        parser.add_argument("last_id", type=str, default=None, location="args")
         parser.add_argument("limit", type=int, default=20, location="args")
         parser.add_argument("status", type=str, action="append", default=[], location="args")
         parser.add_argument("hit_count_gte", type=int, default=None, location="args")
         parser.add_argument("enabled", type=str, default="all", location="args")
         parser.add_argument("keyword", type=str, default=None, location="args")
+        parser.add_argument("page", type=int, default=1, location="args")
+
         args = parser.parse_args()
 
-        last_id = args["last_id"]
+        page = args["page"]
         limit = min(args["limit"], 100)
         status_list = args["status"]
         hit_count_gte = args["hit_count_gte"]
@@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
 
         query = DocumentSegment.query.filter(
             DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        )
-
-        if last_id is not None:
-            last_segment = db.session.get(DocumentSegment, str(last_id))
-            if last_segment:
-                query = query.filter(DocumentSegment.position > last_segment.position)
-            else:
-                return {"data": [], "has_more": False, "limit": limit}, 200
+        ).order_by(DocumentSegment.position.asc())
 
         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
@@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
             elif args["enabled"].lower() == "false":
                 query = query.filter(DocumentSegment.enabled == False)
 
-        total = query.count()
-        segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
-
-        has_more = False
-        if len(segments) > limit:
-            has_more = True
-            segments = segments[:-1]
+        segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
 
-        return {
-            "data": marshal(segments, segment_fields),
-            "doc_form": document.doc_form,
-            "has_more": has_more,
+        response = {
+            "data": marshal(segments.items, segment_fields),
             "limit": limit,
-            "total": total,
-        }, 200
+            "total": segments.total,
+            "total_pages": segments.pages,
+            "page": page,
+        }
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id, document_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        segment_ids = request.args.getlist("segment_id")
+
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        SegmentService.delete_segments(segment_ids, document, dataset)
+        return {"result": "success"}, 200
 
 
 class DatasetDocumentSegmentApi(Resource):
@@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    def patch(self, dataset_id, segment_id, action):
+    def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound("Dataset not found.")
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
                 )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
+        segment_ids = request.args.getlist("segment_id")
 
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-        ).first()
-
-        if not segment:
-            raise NotFound("Segment not found.")
-
-        if segment.status != "completed":
-            raise NotFound("Segment is not completed, enable or disable function is not allowed")
-
-        document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
+        document_indexing_cache_key = "document_{}_indexing".format(document.id)
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")
-
-        indexing_cache_key = "segment_{}_indexing".format(segment.id)
-        cache_result = redis_client.get(indexing_cache_key)
-        if cache_result is not None:
-            raise InvalidActionError("Segment is being indexed, please try again later")
-
-        if action == "enable":
-            if segment.enabled:
-                raise InvalidActionError("Segment is already enabled.")
-
-            segment.enabled = True
-            segment.disabled_at = None
-            segment.disabled_by = None
-            db.session.commit()
-
-            # Set cache to prevent indexing the same segment multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
-
-            enable_segment_to_index_task.delay(segment.id)
-
-            return {"result": "success"}, 200
-        elif action == "disable":
-            if not segment.enabled:
-                raise InvalidActionError("Segment is already disabled.")
-
-            segment.enabled = False
-            segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
-            segment.disabled_by = current_user.id
-            db.session.commit()
-
-            # Set cache to prevent indexing the same segment multiple times
-            redis_client.setex(indexing_cache_key, 600, 1)
-
-            disable_segment_from_index_task.delay(segment.id)
-
-            return {"result": "success"}, 200
-        else:
-            raise InvalidActionError()
+        try:
+            SegmentService.update_segments_status(segment_ids, action, dataset, document)
+        except Exception as e:
+            raise InvalidActionError(str(e))
+        return {"result": "success"}, 200
 
 
 class DatasetDocumentSegmentAddApi(Resource):
@@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         parser.add_argument("content", type=str, required=True, nullable=False, location="json")
         parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
         parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
+        parser.add_argument(
+            "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
+        )
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(args, segment, document, dataset)
+        segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
     @setup_required
@@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
         return {"job_id": job_id, "job_status": cache_result.decode()}, 200
 
 
+class ChildChunkAddApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
+    def post(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        if not current_user.is_editor:
+            raise Forbidden()
+        # check embedding model setting
+        if dataset.indexing_technique == "high_quality":
+            try:
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model,
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    "No Embedding Model available. Please configure a valid provider "
+                    "in the Settings -> Model Provider."
+                )
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunk, child_chunk_fields)}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        parser = reqparse.RequestParser()
+        parser.add_argument("limit", type=int, default=20, location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
+        parser.add_argument("page", type=int, default=1, location="args")
+
+        args = parser.parse_args()
+
+        page = args["page"]
+        limit = min(args["limit"], 100)
+        keyword = args["keyword"]
+
+        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
+        return {
+            "data": marshal(child_chunks.items, child_chunk_fields),
+            "total": child_chunks.total,
+            "total_pages": child_chunks.pages,
+            "page": page,
+            "limit": limit,
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    def patch(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+            # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
+            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunks, child_chunk_fields)}, 200
+
+
+class ChildChunkUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # check child chunk
+        child_chunk_id = str(child_chunk_id)
+        child_chunk = ChildChunk.query.filter(
+            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not child_chunk:
+            raise NotFound("Child chunk not found.")
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        try:
+            SegmentService.delete_child_chunk(child_chunk, dataset)
+        except ChildChunkDeleteIndexServiceError as e:
+            raise ChildChunkDeleteIndexError(str(e))
+        return {"result": "success"}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+            # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # check child chunk
+        child_chunk_id = str(child_chunk_id)
+        child_chunk = ChildChunk.query.filter(
+            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not child_chunk:
+            raise NotFound("Child chunk not found.")
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            child_chunk = SegmentService.update_child_chunk(
+                args.get("content"), child_chunk, segment, document, dataset
+            )
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunk, child_chunk_fields)}, 200
+
+
 api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
-api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
+api.add_resource(
+    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
+)
 api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
 api.add_resource(
     DatasetDocumentSegmentUpdateApi,
@@ -424,3 +651,11 @@ api.add_resource(
     "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
     "/datasets/batch_import_status/<uuid:job_id>",
 )
+api.add_resource(
+    ChildChunkAddApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
+)
+api.add_resource(
+    ChildChunkUpdateApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
+)

+ 12 - 0
api/controllers/console/datasets/error.py

@@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
     error_code = "indexing_estimate_error"
     description = "Knowledge indexing estimate failed: {message}"
     code = 500
+
+
+class ChildChunkIndexingError(BaseHTTPException):
+    error_code = "child_chunk_indexing_error"
+    description = "Create child chunk index failed: {message}"
+    code = 500
+
+
+class ChildChunkDeleteIndexError(BaseHTTPException):
+    error_code = "child_chunk_delete_index_error"
+    description = "Delete child chunk index failed: {message}"
+    code = 500

+ 2 - 1
api/controllers/service_api/dataset/segment.py

@@ -16,6 +16,7 @@ from extensions.ext_database import db
 from fields.segment_fields import segment_fields
 from models.dataset import Dataset, DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
 
 
 class SegmentApi(DatasetApiResource):
@@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
         args = parser.parse_args()
 
         SegmentService.segment_create_args_validate(args["segment"], document)
-        segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
+        segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
         return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 
 

+ 19 - 0
api/core/entities/knowledge_entities.py

@@ -0,0 +1,19 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class PreviewDetail(BaseModel):
+    content: str
+    child_chunks: Optional[list[str]] = None
+
+
+class QAPreviewDetail(BaseModel):
+    question: str
+    answer: str
+
+
+class IndexingEstimate(BaseModel):
+    total_segments: int
+    preview: list[PreviewDetail]
+    qa_preview: Optional[list[QAPreviewDetail]] = None
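
These models are what `IndexingRunner.indexing_estimate` now returns (see the controllers earlier in the commit that switched to `response.model_dump()`). A small sketch of how they compose, using only the fields defined above:

from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail  # module added in this commit

estimate = IndexingEstimate(
    total_segments=2,
    preview=[
        PreviewDetail(content="Parent chunk text", child_chunks=["child one", "child two"]),
        PreviewDetail(content="Another parent chunk"),
    ],
)
# Controllers return the plain-dict form so Flask-RESTful can JSON-encode it:
print(estimate.model_dump())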

+ 68 - 212
api/core/indexing_runner.py

@@ -8,34 +8,34 @@ import time
 import uuid
 from typing import Any, Optional, cast
 
-from flask import Flask, current_app
+from flask import current_app
 from flask_login import current_user  # type: ignore
 from sqlalchemy.orm.exc import ObjectDeletedError
 
 from configs import dify_config
+from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
 from core.errors.error import ProviderTokenNotInitError
-from core.llm_generator.llm_generator import LLMGenerator
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
 from core.rag.splitter.fixed_text_splitter import (
     EnhanceRecursiveCharacterTextSplitter,
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
-from core.tools.utils.text_processing_utils import remove_leading_symbols
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from libs import helper
-from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
 from services.feature_service import FeatureService
@@ -115,6 +115,9 @@ class IndexingRunner:
 
             for document_segment in document_segments:
                 db.session.delete(document_segment)
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # delete child chunks
+                    db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
             processing_rule = (
@@ -183,7 +186,22 @@ class IndexingRunner:
                                 "dataset_id": document_segment.dataset_id,
                             },
                         )
-
+                        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                            child_chunks = document_segment.child_chunks
+                            if child_chunks:
+                                child_documents = []
+                                for child_chunk in child_chunks:
+                                    child_document = ChildDocument(
+                                        page_content=child_chunk.content,
+                                        metadata={
+                                            "doc_id": child_chunk.index_node_id,
+                                            "doc_hash": child_chunk.index_node_hash,
+                                            "document_id": document_segment.document_id,
+                                            "dataset_id": document_segment.dataset_id,
+                                        },
+                                    )
+                                    child_documents.append(child_document)
+                                document.children = child_documents
                         documents.append(document)
 
             # build index
@@ -222,7 +240,7 @@ class IndexingRunner:
         doc_language: str = "English",
         dataset_id: Optional[str] = None,
         indexing_technique: str = "economy",
-    ) -> dict:
+    ) -> IndexingEstimate:
         """
         Estimate the indexing for the document.
         """
@@ -258,31 +276,38 @@ class IndexingRunner:
                     tenant_id=tenant_id,
                     model_type=ModelType.TEXT_EMBEDDING,
                 )
-        preview_texts: list[str] = []
+        preview_texts = []
+
         total_segments = 0
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        all_text_docs = []
         for extract_setting in extract_settings:
             # extract
-            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
-            all_text_docs.extend(text_docs)
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
             )
-
-            # get splitter
-            splitter = self._get_splitter(processing_rule, embedding_model_instance)
-
-            # split to documents
-            documents = self._split_to_documents_for_estimate(
-                text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            documents = index_processor.transform(
+                text_docs,
+                embedding_model_instance=embedding_model_instance,
+                process_rule=processing_rule.to_dict(),
+                tenant_id=current_user.current_tenant_id,
+                doc_language=doc_language,
+                preview=True,
             )
-
             total_segments += len(documents)
             for document in documents:
-                if len(preview_texts) < 5:
-                    preview_texts.append(document.page_content)
+                if len(preview_texts) < 10:
+                    if doc_form and doc_form == "qa_model":
+                        preview_detail = QAPreviewDetail(
+                            question=document.page_content, answer=document.metadata.get("answer")
+                        )
+                        preview_texts.append(preview_detail)
+                    else:
+                        preview_detail = PreviewDetail(content=document.page_content)
+                        if document.children:
+                            preview_detail.child_chunks = [child.page_content for child in document.children]
+                        preview_texts.append(preview_detail)
 
                 # delete image files and related db records
                 image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@@ -299,15 +324,8 @@ class IndexingRunner:
                     db.session.delete(image_file)
 
         if doc_form and doc_form == "qa_model":
-            if len(preview_texts) > 0:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(
-                    current_user.current_tenant_id, preview_texts[0], doc_language
-                )
-                document_qa_list = self.format_split_text(response)
-
-                return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
-        return {"total_segments": total_segments, "preview": preview_texts}
+            return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
 
     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -401,31 +419,26 @@ class IndexingRunner:
 
     @staticmethod
     def _get_splitter(
-        processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
+        processing_rule_mode: str,
+        max_tokens: int,
+        chunk_overlap: int,
+        separator: str,
+        embedding_model_instance: Optional[ModelInstance],
     ) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
-        character_splitter: TextSplitter
-        if processing_rule.mode == "custom":
+        if processing_rule_mode in ["custom", "hierarchical"]:
             # The user-defined segmentation rule
-            rules = json.loads(processing_rule.rules)
-            segmentation = rules["segmentation"]
             max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
-            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
 
-            separator = segmentation["separator"]
             if separator:
                 separator = separator.replace("\\n", "\n")
 
-            if segmentation.get("chunk_overlap"):
-                chunk_overlap = segmentation["chunk_overlap"]
-            else:
-                chunk_overlap = 0
-
             character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
-                chunk_size=segmentation["max_tokens"],
+                chunk_size=max_tokens,
                 chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
                 separators=["\n\n", "。", ". ", " ", ""],
@@ -443,142 +456,6 @@ class IndexingRunner:
 
         return character_splitter
 
-    def _step_split(
-        self,
-        text_docs: list[Document],
-        splitter: TextSplitter,
-        dataset: Dataset,
-        dataset_document: DatasetDocument,
-        processing_rule: DatasetProcessRule,
-    ) -> list[Document]:
-        """
-        Split the text documents into documents and save them to the document segment.
-        """
-        documents = self._split_to_documents(
-            text_docs=text_docs,
-            splitter=splitter,
-            processing_rule=processing_rule,
-            tenant_id=dataset.tenant_id,
-            document_form=dataset_document.doc_form,
-            document_language=dataset_document.doc_language,
-        )
-
-        # save node to document segment
-        doc_store = DatasetDocumentStore(
-            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
-        )
-
-        # add document segments
-        doc_store.add_documents(documents)
-
-        # update document status to indexing
-        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        self._update_document_index_status(
-            document_id=dataset_document.id,
-            after_indexing_status="indexing",
-            extra_update_params={
-                DatasetDocument.cleaning_completed_at: cur_time,
-                DatasetDocument.splitting_completed_at: cur_time,
-            },
-        )
-
-        # update segment status to indexing
-        self._update_segments_by_document(
-            dataset_document_id=dataset_document.id,
-            update_params={
-                DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-            },
-        )
-
-        return documents
-
-    def _split_to_documents(
-        self,
-        text_docs: list[Document],
-        splitter: TextSplitter,
-        processing_rule: DatasetProcessRule,
-        tenant_id: str,
-        document_form: str,
-        document_language: str,
-    ) -> list[Document]:
-        """
-        Split the text documents into nodes.
-        """
-        all_documents: list[Document] = []
-        all_qa_documents: list[Document] = []
-        for text_doc in text_docs:
-            # document clean
-            document_text = self._document_clean(text_doc.page_content, processing_rule)
-            text_doc.page_content = document_text
-
-            # parse document to nodes
-            documents = splitter.split_documents([text_doc])
-            split_documents = []
-            for document_node in documents:
-                if document_node.page_content.strip():
-                    if document_node.metadata is not None:
-                        doc_id = str(uuid.uuid4())
-                        hash = helper.generate_text_hash(document_node.page_content)
-                        document_node.metadata["doc_id"] = doc_id
-                        document_node.metadata["doc_hash"] = hash
-                    # delete Splitter character
-                    page_content = document_node.page_content
-                    document_node.page_content = remove_leading_symbols(page_content)
-
-                    if document_node.page_content:
-                        split_documents.append(document_node)
-            all_documents.extend(split_documents)
-        # processing qa document
-        if document_form == "qa_model":
-            for i in range(0, len(all_documents), 10):
-                threads = []
-                sub_documents = all_documents[i : i + 10]
-                for doc in sub_documents:
-                    document_format_thread = threading.Thread(
-                        target=self.format_qa_document,
-                        kwargs={
-                            "flask_app": current_app._get_current_object(),  # type: ignore
-                            "tenant_id": tenant_id,
-                            "document_node": doc,
-                            "all_qa_documents": all_qa_documents,
-                            "document_language": document_language,
-                        },
-                    )
-                    threads.append(document_format_thread)
-                    document_format_thread.start()
-                for thread in threads:
-                    thread.join()
-            return all_qa_documents
-        return all_documents
-
-    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
-        format_documents = []
-        if document_node.page_content is None or not document_node.page_content.strip():
-            return
-        with flask_app.app_context():
-            try:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
-                document_qa_list = self.format_split_text(response)
-                qa_documents = []
-                for result in document_qa_list:
-                    qa_document = Document(
-                        page_content=result["question"], metadata=document_node.metadata.model_copy()
-                    )
-                    if qa_document.metadata is not None:
-                        doc_id = str(uuid.uuid4())
-                        hash = helper.generate_text_hash(result["question"])
-                        qa_document.metadata["answer"] = result["answer"]
-                        qa_document.metadata["doc_id"] = doc_id
-                        qa_document.metadata["doc_hash"] = hash
-                    qa_documents.append(qa_document)
-                format_documents.extend(qa_documents)
-            except Exception as e:
-                logging.exception("Failed to format qa document")
-
-            all_qa_documents.extend(format_documents)
-
     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
     ) -> list[Document]:
@@ -624,11 +501,11 @@ class IndexingRunner:
         return document_text
 
     @staticmethod
-    def format_split_text(text):
+    def format_split_text(text: str) -> list[QAPreviewDetail]:
         regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
         matches = re.findall(regex, text, re.UNICODE)
 
-        return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]
+        return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
 
     def _load(
         self,
@@ -654,13 +531,14 @@ class IndexingRunner:
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 10
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+            # create keyword index
+            create_keyword_thread = threading.Thread(
+                target=self._process_keyword_index,
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+            )
+            create_keyword_thread.start()
 
-        # create keyword index
-        create_keyword_thread = threading.Thread(
-            target=self._process_keyword_index,
-            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
-        )
-        create_keyword_thread.start()
         if dataset.indexing_technique == "high_quality":
             with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                 futures = []
@@ -680,8 +558,8 @@ class IndexingRunner:
 
                 for future in futures:
                     tokens += future.result()
-
-        create_keyword_thread.join()
+        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+            create_keyword_thread.join()
         indexing_end_at = time.perf_counter()
 
         # update document status to completed
@@ -793,28 +671,6 @@ class IndexingRunner:
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
 
-    @staticmethod
-    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
-        """
-        Batch add segments index processing
-        """
-        documents = []
-        for segment in segments:
-            document = Document(
-                page_content=segment.content,
-                metadata={
-                    "doc_id": segment.index_node_id,
-                    "doc_hash": segment.index_node_hash,
-                    "document_id": segment.document_id,
-                    "dataset_id": segment.dataset_id,
-                },
-            )
-            documents.append(document)
-        # save vector index
-        index_type = dataset.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.load(dataset, documents)
-
     def _transform(
         self,
         index_processor: BaseIndexProcessor,
@@ -856,7 +712,7 @@ class IndexingRunner:
         )
 
         # add document segments
-        doc_store.add_documents(documents)
+        doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
 
         # update document status to indexing
         cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
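
As a quick illustration of the typed `format_split_text` above: the regex is unchanged and still pulls `Qn:`/`An:` pairs out of the LLM output, each pair now becoming a `QAPreviewDetail` instead of a plain dict. A minimal, standalone sketch of that parsing step (the sample text is made up):

import re

regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
sample = (
    "Q1: What is a parent chunk?\n"
    "A1: The larger segment that provides context at query time.\n"
    "Q2: And a child chunk?\n"
    "A2: The smaller piece that is embedded and matched."
)

# Each (question, answer) tuple is what gets wrapped in a QAPreviewDetail.
for q, a in re.findall(regex, sample, re.UNICODE):
    print(q, "->", re.sub(r"\n\s*", "\n", a.strip()))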

+ 89 - 1
api/core/rag/datasource/retrieval_service.py

@@ -6,11 +6,14 @@ from flask import Flask, current_app
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
-from models.dataset import Dataset
+from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService
 
 default_retrieval_model = {
@@ -248,3 +251,88 @@ class RetrievalService:
     @staticmethod
     def escape_query_for_search(query: str) -> str:
         return query.replace('"', '\\"')
+
+    @staticmethod
+    def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
+        records = []
+        include_segment_ids = []
+        segment_child_map = {}
+        for document in documents:
+            document_id = document.metadata["document_id"]
+            dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                child_index_node_id = document.metadata["doc_id"]
+                result = (
+                    db.session.query(ChildChunk, DocumentSegment)
+                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                    .filter(
+                        ChildChunk.index_node_id == child_index_node_id,
+                        DocumentSegment.dataset_id == dataset_document.dataset_id,
+                        DocumentSegment.enabled == True,
+                        DocumentSegment.status == "completed",
+                    )
+                    .first()
+                )
+                if result:
+                    child_chunk, segment = result
+                    if not segment:
+                        continue
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.append(segment.id)
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
+                        segment_child_map[segment.id] = map_detail
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
+                else:
+                    continue
+            elif dataset_document:
+                index_node_id = document.metadata["doc_id"]
+
+                segment = (
+                    db.session.query(DocumentSegment)
+                    .filter(
+                        DocumentSegment.dataset_id == dataset_document.dataset_id,
+                        DocumentSegment.enabled == True,
+                        DocumentSegment.status == "completed",
+                        DocumentSegment.index_node_id == index_node_id,
+                    )
+                    .first()
+                )
+
+                if not segment:
+                    continue
+                include_segment_ids.append(segment.id)
+                record = {
+                    "segment": segment,
+                    "score": document.metadata.get("score", None),
+                }
+
+                records.append(record)
+        for record in records:
+            if record["segment"].id in segment_child_map:
+                record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
+                record["score"] = segment_child_map[record["segment"].id]["max_score"]
+
+        return [RetrievalSegments(**record) for record in records]
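
The new `format_retrieval_documents` helper above is essentially a group-by: hits on child chunks are folded into their parent segment, the parent keeps the best child score, and ordinary paragraph/QA hits pass through as single-segment records. A dependency-free sketch of that grouping idea on plain dicts (the names here are illustrative, not the service's API):

from collections import defaultdict

# (segment_id, child_chunk_id, score) triples, as if returned by the vector search
hits = [("seg-1", "child-a", 0.82), ("seg-1", "child-b", 0.91), ("seg-2", "child-c", 0.60)]

grouped: dict[str, dict] = defaultdict(lambda: {"max_score": 0.0, "child_chunks": []})
for segment_id, child_id, score in hits:
    record = grouped[segment_id]
    record["child_chunks"].append({"id": child_id, "score": score})
    record["max_score"] = max(record["max_score"], score)

# Each entry now mirrors one RetrievalSegments record: a parent segment plus its scored children.
for segment_id, record in grouped.items():
    print(segment_id, record["max_score"], [c["id"] for c in record["child_chunks"]])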

+ 43 - 2
api/core/rag/docstore/dataset_docstore.py

@@ -7,7 +7,7 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.models.document import Document
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment
 
 
 class DatasetDocumentStore:
@@ -60,7 +60,7 @@ class DatasetDocumentStore:
 
         return output
 
-    def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
+    def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
         max_position = (
             db.session.query(func.max(DocumentSegment.position))
             .filter(DocumentSegment.document_id == self._document_id)
@@ -120,6 +120,23 @@ class DatasetDocumentStore:
                     segment_document.answer = doc.metadata.pop("answer", "")
 
                 db.session.add(segment_document)
+                db.session.flush()
+                if save_child and doc.children:
+                    for position, child in enumerate(doc.children, start=1):
+                        child_segment = ChildChunk(
+                            tenant_id=self._dataset.tenant_id,
+                            dataset_id=self._dataset.id,
+                            document_id=self._document_id,
+                            segment_id=segment_document.id,
+                            position=position,
+                            index_node_id=child.metadata["doc_id"],
+                            index_node_hash=child.metadata["doc_hash"],
+                            content=child.page_content,
+                            word_count=len(child.page_content),
+                            type="automatic",
+                            created_by=self._user_id,
+                        )
+                        db.session.add(child_segment)
             else:
                 segment_document.content = doc.page_content
                 if doc.metadata.get("answer"):
@@ -127,6 +144,30 @@ class DatasetDocumentStore:
                 segment_document.index_node_hash = doc.metadata["doc_hash"]
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
+                if save_child and doc.children:
+                    # delete the existing child chunks
+                    db.session.query(ChildChunk).filter(
+                        ChildChunk.tenant_id == self._dataset.tenant_id,
+                        ChildChunk.dataset_id == self._dataset.id,
+                        ChildChunk.document_id == self._document_id,
+                        ChildChunk.segment_id == segment_document.id,
+                    ).delete()
+                    # add new child chunks
+                    for position, child in enumerate(doc.children, start=1):
+                        child_segment = ChildChunk(
+                            tenant_id=self._dataset.tenant_id,
+                            dataset_id=self._dataset.id,
+                            document_id=self._document_id,
+                            segment_id=segment_document.id,
+                            position=position,
+                            index_node_id=child.metadata["doc_id"],
+                            index_node_hash=child.metadata["doc_hash"],
+                            content=child.page_content,
+                            word_count=len(child.page_content),
+                            type="automatic",
+                            created_by=self._user_id,
+                        )
+                        db.session.add(child_segment)
 
             db.session.commit()
 

+ 23 - 0
api/core/rag/embedding/retrieval.py

@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from models.dataset import DocumentSegment
+
+
+class RetrievalChildChunk(BaseModel):
+    """Retrieval segments."""
+
+    id: str
+    content: str
+    score: float
+    position: int
+
+
+class RetrievalSegments(BaseModel):
+    """Retrieval segments."""
+
+    model_config = {"arbitrary_types_allowed": True}
+    segment: DocumentSegment
+    child_chunks: Optional[list[RetrievalChildChunk]] = None
+    score: Optional[float] = None

+ 1 - 6
api/core/rag/extractor/extract_processor.py

@@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst
 from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
 from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
 from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
-from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
 from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
 from core.rag.extractor.word_extractor import WordExtractor
 from core.rag.models.document import Document
@@ -141,11 +140,7 @@ class ExtractProcessor:
                         extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
                     else:
                         # txt
-                        extractor = (
-                            UnstructuredTextExtractor(file_path, unstructured_api_url)
-                            if is_automatic
-                            else TextExtractor(file_path, autodetect_encoding=True)
-                        )
+                        extractor = TextExtractor(file_path, autodetect_encoding=True)
                 else:
                     if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)

+ 3 - 1
api/core/rag/extractor/word_extractor.py

@@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor):
                 if isinstance(element.tag, str) and element.tag.endswith("p"):  # paragraph
                     para = paragraphs.pop(0)
                     parsed_paragraph = parse_paragraph(para)
-                    if parsed_paragraph:
+                    if parsed_paragraph.strip():
                         content.append(parsed_paragraph)
+                    else:
+                        content.append("\n")
                 elif isinstance(element.tag, str) and element.tag.endswith("tbl"):  # table
                     table = tables.pop(0)
                     content.append(self._table_to_markdown(table, image_map))

+ 2 - 3
api/core/rag/index_processor/constant/index_type.py

@@ -1,8 +1,7 @@
 from enum import Enum
 
 
-class IndexType(Enum):
+class IndexType(str, Enum):
     PARAGRAPH_INDEX = "text_model"
     QA_INDEX = "qa_model"
-    PARENT_CHILD_INDEX = "parent_child_index"
-    SUMMARY_INDEX = "summary_index"
+    PARENT_CHILD_INDEX = "hierarchical_model"
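
Making `IndexType` a `str`-mixed enum is what lets `doc_form` strings coming from the database be compared against the members directly, which the factory change below relies on. A standalone check of that behavior:

from enum import Enum

class IndexType(str, Enum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"

# A str-mixin enum compares equal to its raw value, so `.value` is no longer needed at call sites.
assert IndexType.PARENT_CHILD_INDEX == "hierarchical_model"
assert "text_model" == IndexType.PARAGRAPH_INDEX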

+ 14 - 11
api/core/rag/index_processor/index_processor_base.py

@@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         raise NotImplementedError
 
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         raise NotImplementedError
 
     @abstractmethod
@@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC):
     ) -> list[Document]:
         raise NotImplementedError
 
-    def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
+    def _get_splitter(
+        self,
+        processing_rule_mode: str,
+        max_tokens: int,
+        chunk_overlap: int,
+        separator: str,
+        embedding_model_instance: Optional[ModelInstance],
+    ) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
-        character_splitter: TextSplitter
-        if processing_rule["mode"] == "custom":
+        if processing_rule_mode in ["custom", "hierarchical"]:
             # The user-defined segmentation rule
-            rules = processing_rule["rules"]
-            segmentation = rules["segmentation"]
             max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
-            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
 
-            separator = segmentation["separator"]
             if separator:
                 separator = separator.replace("\\n", "\n")
 
             character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
-                chunk_size=segmentation["max_tokens"],
-                chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
+                chunk_size=max_tokens,
+                chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
                 separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance,

+ 5 - 2
api/core/rag/index_processor/index_processor_factory.py

@@ -3,6 +3,7 @@
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
 from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
 
 
@@ -18,9 +19,11 @@ class IndexProcessorFactory:
         if not self._index_type:
             raise ValueError("Index type must be specified.")
 
-        if self._index_type == IndexType.PARAGRAPH_INDEX.value:
+        if self._index_type == IndexType.PARAGRAPH_INDEX:
             return ParagraphIndexProcessor()
-        elif self._index_type == IndexType.QA_INDEX.value:
+        elif self._index_type == IndexType.QA_INDEX:
             return QAIndexProcessor()
+        elif self._index_type == IndexType.PARENT_CHILD_INDEX:
+            return ParentChildIndexProcessor()
         else:
             raise ValueError(f"Index type {self._index_type} is not supported.")

+ 23 - 6
api/core/rag/index_processor/processor/paragraph_index_processor.py

@@ -13,21 +13,34 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import Document
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetProcessRule
+from services.entities.knowledge_entities.knowledge_entities import Rule
 
 
 class ParagraphIndexProcessor(BaseIndexProcessor):
     def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         text_docs = ExtractProcessor.extract(
-            extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
+            extract_setting=extract_setting,
+            is_automatic=(
+                kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
+            ),
         )
 
         return text_docs
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        process_rule = kwargs.get("process_rule")
+        if process_rule.get("mode") == "automatic":
+            automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
+            rules = Rule(**automatic_rule)
+        else:
+            rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
         splitter = self._get_splitter(
-            processing_rule=kwargs.get("process_rule", {}),
+            processing_rule_mode=process_rule.get("mode"),
+            max_tokens=rules.segmentation.max_tokens,
+            chunk_overlap=rules.segmentation.chunk_overlap,
+            separator=rules.segmentation.separator,
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
         all_documents = []
@@ -53,15 +66,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             all_documents.extend(split_documents)
         return all_documents
 
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)
         if with_keywords:
+            keywords_list = kwargs.get("keywords_list")
             keyword = Keyword(dataset)
-            keyword.create(documents)
+            if keywords_list and len(keywords_list) > 0:
+                keyword.add_texts(documents, keywords_list=keywords_list)
+            else:
+                keyword.add_texts(documents)
 
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             if node_ids:
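
For reference, the `process_rule` kwarg that `transform` now unpacks matches the dict produced by `DatasetProcessRule.to_dict()` elsewhere in this change. A hand-written example of the shape read by this processor and by the parent-child processor that follows; the concrete values, and the `parent_mode` strings, are assumptions for illustration, and a real rule dict may carry additional keys such as `pre_processing_rules`:

process_rule = {
    "mode": "hierarchical",
    "rules": {
        "segmentation": {"separator": "\\n\\n", "max_tokens": 500, "chunk_overlap": 50},
        "parent_mode": "paragraph",  # assumed ParentMode value; a full-document mode indexes the whole file as one parent
        "subchunk_segmentation": {"separator": "\\n", "max_tokens": 200, "chunk_overlap": 20},
    },
}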

+ 189 - 0
api/core/rag/index_processor/processor/parent_child_index_processor.py

@@ -0,0 +1,189 @@
+"""Paragraph index processor."""
+
+import uuid
+from typing import Optional
+
+from core.model_manager import ModelInstance
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import ChildDocument, Document
+from extensions.ext_database import db
+from libs import helper
+from models.dataset import ChildChunk, Dataset, DocumentSegment
+from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
+
+
+class ParentChildIndexProcessor(BaseIndexProcessor):
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+        text_docs = ExtractProcessor.extract(
+            extract_setting=extract_setting,
+            is_automatic=(
+                kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
+            ),
+        )
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        process_rule = kwargs.get("process_rule")
+        rules = Rule(**process_rule.get("rules"))
+        all_documents = []
+        if rules.parent_mode == ParentMode.PARAGRAPH:
+            # Split the text documents into nodes.
+            splitter = self._get_splitter(
+                processing_rule_mode=process_rule.get("mode"),
+                max_tokens=rules.segmentation.max_tokens,
+                chunk_overlap=rules.segmentation.chunk_overlap,
+                separator=rules.segmentation.separator,
+                embedding_model_instance=kwargs.get("embedding_model_instance"),
+            )
+            for document in documents:
+                # document clean
+                document_text = CleanProcessor.clean(document.page_content, process_rule)
+                document.page_content = document_text
+                # parse document to nodes
+                document_nodes = splitter.split_documents([document])
+                split_documents = []
+                for document_node in document_nodes:
+                    if document_node.page_content.strip():
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(document_node.page_content)
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
+                        # strip a leading splitter character left over from splitting
+                        page_content = document_node.page_content
+                        if page_content.startswith(".") or page_content.startswith("。"):
+                            page_content = page_content[1:].strip()
+                        if len(page_content) > 0:
+                            document_node.page_content = page_content
+                            # parse document to child nodes
+                            child_nodes = self._split_child_nodes(
+                                document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
+                            )
+                            document_node.children = child_nodes
+                            split_documents.append(document_node)
+                all_documents.extend(split_documents)
+        elif rules.parent_mode == ParentMode.FULL_DOC:
+            page_content = "\n".join([document.page_content for document in documents])
+            document = Document(page_content=page_content, metadata=documents[0].metadata)
+            # parse document to child nodes
+            child_nodes = self._split_child_nodes(
+                document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
+            )
+            document.children = child_nodes
+            doc_id = str(uuid.uuid4())
+            hash = helper.generate_text_hash(document.page_content)
+            document.metadata["doc_id"] = doc_id
+            document.metadata["doc_hash"] = hash
+            all_documents.append(document)
+
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
+        if dataset.indexing_technique == "high_quality":
+            vector = Vector(dataset)
+            for document in documents:
+                child_documents = document.children
+                if child_documents:
+                    formatted_child_documents = [
+                        Document(**child_document.model_dump()) for child_document in child_documents
+                    ]
+                    vector.create(formatted_child_documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
+        # node_ids is segment's node_ids
+        if dataset.indexing_technique == "high_quality":
+            delete_child_chunks = kwargs.get("delete_child_chunks") or False
+            vector = Vector(dataset)
+            if node_ids:
+                child_node_ids = (
+                    db.session.query(ChildChunk.index_node_id)
+                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                    .filter(
+                        DocumentSegment.dataset_id == dataset.id,
+                        DocumentSegment.index_node_id.in_(node_ids),
+                        ChildChunk.dataset_id == dataset.id,
+                    )
+                    .all()
+                )
+                child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
+                vector.delete_by_ids(child_node_ids)
+                if delete_child_chunks:
+                    db.session.query(ChildChunk).filter(
+                        ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
+                    ).delete()
+                    db.session.commit()
+            else:
+                vector.delete()
+
+                if delete_child_chunks:
+                    db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
+                    db.session.commit()
+
+    def retrieve(
+        self,
+        retrieval_method: str,
+        query: str,
+        dataset: Dataset,
+        top_k: int,
+        score_threshold: float,
+        reranking_model: dict,
+    ) -> list[Document]:
+        # Set search parameters.
+        results = RetrievalService.retrieve(
+            retrieval_method=retrieval_method,
+            dataset_id=dataset.id,
+            query=query,
+            top_k=top_k,
+            score_threshold=score_threshold,
+            reranking_model=reranking_model,
+        )
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata["score"] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def _split_child_nodes(
+        self,
+        document_node: Document,
+        rules: Rule,
+        process_rule_mode: str,
+        embedding_model_instance: Optional[ModelInstance],
+    ) -> list[ChildDocument]:
+        child_splitter = self._get_splitter(
+            processing_rule_mode=process_rule_mode,
+            max_tokens=rules.subchunk_segmentation.max_tokens,
+            chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
+            separator=rules.subchunk_segmentation.separator,
+            embedding_model_instance=embedding_model_instance,
+        )
+        # parse document to child nodes
+        child_nodes = []
+        child_documents = child_splitter.split_documents([document_node])
+        for child_document_node in child_documents:
+            if child_document_node.page_content.strip():
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(child_document_node.page_content)
+                child_document = ChildDocument(
+                    page_content=child_document_node.page_content, metadata=document_node.metadata
+                )
+                child_document.metadata["doc_id"] = doc_id
+                child_document.metadata["doc_hash"] = hash
+                child_page_content = child_document.page_content
+                if child_page_content.startswith(".") or child_page_content.startswith("。"):
+                    child_page_content = child_page_content[1:].strip()
+                if len(child_page_content) > 0:
+                    child_document.page_content = child_page_content
+                    child_nodes.append(child_document)
+        return child_nodes
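
To make the transform above concrete: each parent chunk is split once more into child chunks; the children are what get hashed, embedded and indexed, while the parent segment is what comes back as context at query time. A dependency-free sketch of that two-level split, with a naive fixed-size splitter standing in for the token-aware FixedRecursiveCharacterTextSplitter:

def split(text: str, size: int) -> list[str]:
    # stand-in for the recursive, token-aware splitter used by the processor
    return [text[i : i + size] for i in range(0, len(text), size)]

document = "A long source document about parent-child chunking. " * 40

parents = split(document, 400)  # parent chunks: retrieval context
parent_to_children = {i: split(p, 100) for i, p in enumerate(parents)}  # child chunks: what gets vectorized

for i, children in parent_to_children.items():
    print(f"parent {i}: {len(children)} child chunks")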

+ 41 - 22
api/core/rag/index_processor/processor/qa_index_processor.py

@@ -21,18 +21,28 @@ from core.rag.models.document import Document
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.dataset import Dataset
+from services.entities.knowledge_entities.knowledge_entities import Rule
 
 
 class QAIndexProcessor(BaseIndexProcessor):
     def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         text_docs = ExtractProcessor.extract(
-            extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
+            extract_setting=extract_setting,
+            is_automatic=(
+                kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
+            ),
         )
         return text_docs
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        preview = kwargs.get("preview")
+        process_rule = kwargs.get("process_rule")
+        rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
-            processing_rule=kwargs.get("process_rule") or {},
+            processing_rule_mode=process_rule.get("mode"),
+            max_tokens=rules.segmentation.max_tokens,
+            chunk_overlap=rules.segmentation.chunk_overlap,
+            separator=rules.segmentation.separator,
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
 
@@ -59,24 +69,33 @@ class QAIndexProcessor(BaseIndexProcessor):
                     document_node.page_content = remove_leading_symbols(page_content)
                     split_documents.append(document_node)
             all_documents.extend(split_documents)
-        for i in range(0, len(all_documents), 10):
-            threads = []
-            sub_documents = all_documents[i : i + 10]
-            for doc in sub_documents:
-                document_format_thread = threading.Thread(
-                    target=self._format_qa_document,
-                    kwargs={
-                        "flask_app": current_app._get_current_object(),  # type: ignore
-                        "tenant_id": kwargs.get("tenant_id"),
-                        "document_node": doc,
-                        "all_qa_documents": all_qa_documents,
-                        "document_language": kwargs.get("doc_language", "English"),
-                    },
-                )
-                threads.append(document_format_thread)
-                document_format_thread.start()
-            for thread in threads:
-                thread.join()
+        if preview:
+            self._format_qa_document(
+                current_app._get_current_object(),
+                kwargs.get("tenant_id"),
+                all_documents[0],
+                all_qa_documents,
+                kwargs.get("doc_language", "English"),
+            )
+        else:
+            for i in range(0, len(all_documents), 10):
+                threads = []
+                sub_documents = all_documents[i : i + 10]
+                for doc in sub_documents:
+                    document_format_thread = threading.Thread(
+                        target=self._format_qa_document,
+                        kwargs={
+                            "flask_app": current_app._get_current_object(),
+                            "tenant_id": kwargs.get("tenant_id"),
+                            "document_node": doc,
+                            "all_qa_documents": all_qa_documents,
+                            "document_language": kwargs.get("doc_language", "English"),
+                        },
+                    )
+                    threads.append(document_format_thread)
+                    document_format_thread.start()
+                for thread in threads:
+                    thread.join()
         return all_qa_documents
 
     def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
@@ -98,12 +117,12 @@ class QAIndexProcessor(BaseIndexProcessor):
             raise ValueError(str(e))
         return text_docs
 
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)
 
-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         vector = Vector(dataset)
         if node_ids:
             vector.delete_by_ids(node_ids)

+ 15 - 0
api/core/rag/models/document.py

@@ -5,6 +5,19 @@ from typing import Any, Optional
 from pydantic import BaseModel, Field
 
 
+class ChildDocument(BaseModel):
+    """Class for storing a piece of text and associated metadata."""
+
+    page_content: str
+
+    vector: Optional[list[float]] = None
+
+    """Arbitrary metadata about the page content (e.g., source, relationships to other
+        documents, etc.).
+    """
+    metadata: Optional[dict] = Field(default_factory=dict)
+
+
 class Document(BaseModel):
     """Class for storing a piece of text and associated metadata."""
 
@@ -19,6 +32,8 @@ class Document(BaseModel):
 
     provider: Optional[str] = "dify"
 
+    children: Optional[list[ChildDocument]] = None
+
 
 class BaseDocumentTransformer(ABC):
     """Abstract base class for document transformation systems.

+ 9 - 23
api/core/rag/retrieval/dataset_retrieval.py

@@ -166,43 +166,29 @@ class DatasetRetrieval:
                 "content": item.page_content,
             }
             retrieval_resource_list.append(source)
-        document_score_list = {}
         # deal with dify documents
         if dify_documents:
-            for item in dify_documents:
-                if item.metadata.get("score"):
-                    document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
-
-            index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
-            segments = DocumentSegment.query.filter(
-                DocumentSegment.dataset_id.in_(dataset_ids),
-                DocumentSegment.status == "completed",
-                DocumentSegment.enabled == True,
-                DocumentSegment.index_node_id.in_(index_node_ids),
-            ).all()
-
-            if segments:
-                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
-                sorted_segments = sorted(
-                    segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
-                )
-                for segment in sorted_segments:
+            records = RetrievalService.format_retrieval_documents(dify_documents)
+            if records:
+                for record in records:
+                    segment = record.segment
                     if segment.answer:
                         document_context_list.append(
                             DocumentContext(
                                 content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
-                                score=document_score_list.get(segment.index_node_id, None),
+                                score=record.score,
                             )
                         )
                     else:
                         document_context_list.append(
                             DocumentContext(
                                 content=segment.get_sign_content(),
-                                score=document_score_list.get(segment.index_node_id, None),
+                                score=record.score,
                             )
                         )
                 if show_retrieve_source:
-                    for segment in sorted_segments:
+                    for record in records:
+                        segment = record.segment
                         dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                         document = DatasetDocument.query.filter(
                             DatasetDocument.id == segment.document_id,
@@ -218,7 +204,7 @@ class DatasetRetrieval:
                                 "data_source_type": document.data_source_type,
                                 "segment_id": segment.id,
                                 "retriever_from": invoke_from.to_source(),
-                                "score": document_score_list.get(segment.index_node_id, 0.0),
+                                "score": record.score or 0.0,
                             }
 
                             if invoke_from.to_source() == "dev":

+ 8 - 26
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.variables import StringSegment
@@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.dataset import Dataset, Document, DocumentSegment
+from models.dataset import Dataset, Document
 from models.workflow import WorkflowNodeExecutionStatus
 
 from .entities import KnowledgeRetrievalNodeData
@@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 "content": item.page_content,
             }
             retrieval_resource_list.append(source)
-        document_score_list: dict[str, float] = {}
         # deal with dify documents
         if dify_documents:
-            document_score_list = {}
-            for item in dify_documents:
-                if item.metadata.get("score"):
-                    document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
-
-            index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
-            segments = DocumentSegment.query.filter(
-                DocumentSegment.dataset_id.in_(dataset_ids),
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.status == "completed",
-                DocumentSegment.enabled == True,
-                DocumentSegment.index_node_id.in_(index_node_ids),
-            ).all()
-            if segments:
-                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
-                sorted_segments = sorted(
-                    segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
-                )
-
-                for segment in sorted_segments:
+            records = RetrievalService.format_retrieval_documents(dify_documents)
+            if records:
+                for record in records:
+                    segment = record.segment
                     dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                     document = Document.query.filter(
                         Document.id == segment.document_id,
@@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                                 "document_data_source_type": document.data_source_type,
                                 "segment_id": segment.id,
                                 "retriever_from": "workflow",
-                                "score": document_score_list.get(segment.index_node_id, None),
+                                "score": record.score or 0.0,
                                 "segment_hit_count": segment.hit_count,
                                 "segment_word_count": segment.word_count,
                                 "segment_position": segment.position,
@@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
                 reverse=True,
             )
-            position = 1
-            for item in retrieval_resource_list:
+            for position, item in enumerate(retrieval_resource_list, start=1):
                 item["metadata"]["position"] = position
-                position += 1
         return retrieval_resource_list
 
     @classmethod

+ 1 - 0
api/fields/dataset_fields.py

@@ -73,6 +73,7 @@ dataset_detail_fields = {
     "embedding_available": fields.Boolean,
     "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
     "tags": fields.List(fields.Nested(tag_fields)),
+    "doc_form": fields.String,
     "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
 }

+ 1 - 0
api/fields/document_fields.py

@@ -34,6 +34,7 @@ document_with_segments_fields = {
     "data_source_info": fields.Raw(attribute="data_source_info_dict"),
     "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
     "dataset_process_rule_id": fields.String,
+    "process_rule_dict": fields.Raw(attribute="process_rule_dict"),
     "name": fields.String,
     "created_from": fields.String,
     "created_by": fields.String,

+ 8 - 0
api/fields/hit_testing_fields.py

@@ -34,8 +34,16 @@ segment_fields = {
     "document": fields.Nested(document_fields),
 }
 
+child_chunk_fields = {
+    "id": fields.String,
+    "content": fields.String,
+    "position": fields.Integer,
+    "score": fields.Float,
+}
+
 hit_testing_record_fields = {
     "segment": fields.Nested(segment_fields),
+    "child_chunks": fields.List(fields.Nested(child_chunk_fields)),
     "score": fields.Float,
     "tsne_position": fields.Raw,
 }

+ 14 - 0
api/fields/segment_fields.py

@@ -2,6 +2,17 @@ from flask_restful import fields  # type: ignore
 
 from libs.helper import TimestampField
 
+child_chunk_fields = {
+    "id": fields.String,
+    "segment_id": fields.String,
+    "content": fields.String,
+    "position": fields.Integer,
+    "word_count": fields.Integer,
+    "type": fields.String,
+    "created_at": TimestampField,
+    "updated_at": TimestampField,
+}
+
 segment_fields = {
     "id": fields.String,
     "position": fields.Integer,
@@ -20,10 +31,13 @@ segment_fields = {
     "status": fields.String,
     "created_by": fields.String,
     "created_at": TimestampField,
+    "updated_at": TimestampField,
+    "updated_by": fields.String,
     "indexing_at": TimestampField,
     "completed_at": TimestampField,
     "error": fields.String,
     "stopped_at": TimestampField,
+    "child_chunks": fields.List(fields.Nested(child_chunk_fields)),
 }
 
 segment_list_response = {

+ 55 - 0
api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py

@@ -0,0 +1,55 @@
+"""parent-child-index
+
+Revision ID: e19037032219
+Revises: 01d6889832f7
+Create Date: 2024-11-22 07:01:17.550037
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'e19037032219'
+down_revision = 'd7999dfa4aae'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('child_chunks',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+    sa.Column('document_id', models.types.StringUUID(), nullable=False),
+    sa.Column('segment_id', models.types.StringUUID(), nullable=False),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('content', sa.Text(), nullable=False),
+    sa.Column('word_count', sa.Integer(), nullable=False),
+    sa.Column('index_node_id', sa.String(length=255), nullable=True),
+    sa.Column('index_node_hash', sa.String(length=255), nullable=True),
+    sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
+    sa.Column('created_by', models.types.StringUUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('indexing_at', sa.DateTime(), nullable=True),
+    sa.Column('completed_at', sa.DateTime(), nullable=True),
+    sa.Column('error', sa.Text(), nullable=True),
+    sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
+    )
+    with op.batch_alter_table('child_chunks', schema=None) as batch_op:
+        batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('child_chunks', schema=None) as batch_op:
+        batch_op.drop_index('child_chunk_dataset_id_idx')
+
+    op.drop_table('child_chunks')
+    # ### end Alembic commands ###

+ 47 - 0
api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py

@@ -0,0 +1,47 @@
+"""add_auto_disabled_dataset_logs
+
+Revision ID: 923752d42eb6
+Revises: e19037032219
+Create Date: 2024-12-25 11:37:55.467101
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '923752d42eb6'
+down_revision = 'e19037032219'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_auto_disable_logs',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+    sa.Column('document_id', models.types.StringUUID(), nullable=False),
+    sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
+    )
+    with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
+        batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
+        batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
+        batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
+        batch_op.drop_index('dataset_auto_disable_log_tenant_idx')
+        batch_op.drop_index('dataset_auto_disable_log_dataset_idx')
+        batch_op.drop_index('dataset_auto_disable_log_created_atx')
+
+    op.drop_table('dataset_auto_disable_logs')
+    # ### end Alembic commands ###

+ 84 - 3
api/models/dataset.py

@@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 from configs import dify_config
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_storage import storage
+from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
 
 from .account import Account
 from .engine import db
@@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model):  # type: ignore[name-defined]
     created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
-    MODES = ["automatic", "custom"]
+    MODES = ["automatic", "custom", "hierarchical"]
     PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
     AUTOMATIC_RULES: dict[str, Any] = {
         "pre_processing_rules": [
@@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model):  # type: ignore[name-defined]
             "dataset_id": self.dataset_id,
             "mode": self.mode,
             "rules": self.rules_dict,
-            "created_by": self.created_by,
-            "created_at": self.created_at,
         }
 
     @property
@@ -396,6 +395,12 @@ class Document(db.Model):  # type: ignore[name-defined]
             .scalar()
         )
 
+    @property
+    def process_rule_dict(self):
+        if self.dataset_process_rule_id:
+            return self.dataset_process_rule.to_dict()
+        return None
+
     def to_dict(self):
         return {
             "id": self.id,
@@ -560,6 +565,24 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
             .first()
         )
 
+    @property
+    def child_chunks(self):
+        process_rule = self.document.dataset_process_rule
+        if process_rule.mode == "hierarchical":
+            rules = Rule(**process_rule.rules_dict)
+            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
+                child_chunks = (
+                    db.session.query(ChildChunk)
+                    .filter(ChildChunk.segment_id == self.id)
+                    .order_by(ChildChunk.position.asc())
+                    .all()
+                )
+                return child_chunks or []
+        return []
+
     def get_sign_content(self):
         signed_urls = []
         text = self.content
@@ -605,6 +628,47 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
         return text
 
 
+class ChildChunk(db.Model):
+    __tablename__ = "child_chunks"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
+        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
+    )
+
+    # initial fields
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    segment_id = db.Column(StringUUID, nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    content = db.Column(db.Text, nullable=False)
+    word_count = db.Column(db.Integer, nullable=False)
+    # indexing fields
+    index_node_id = db.Column(db.String(255), nullable=True)
+    index_node_hash = db.Column(db.String(255), nullable=True)
+    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    updated_by = db.Column(StringUUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    indexing_at = db.Column(db.DateTime, nullable=True)
+    completed_at = db.Column(db.DateTime, nullable=True)
+    error = db.Column(db.Text, nullable=True)
+
+    @property
+    def dataset(self):
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
+
+    @property
+    def document(self):
+        return db.session.query(Document).filter(Document.id == self.document_id).first()
+
+    @property
+    def segment(self):
+        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
+
+
 class AppDatasetJoin(db.Model):  # type: ignore[name-defined]
     __tablename__ = "app_dataset_joins"
     __table_args__ = (
@@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model):  # type: ignore[name-defined]
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+
+class DatasetAutoDisableLog(db.Model):
+    __tablename__ = "dataset_auto_disable_logs"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
+        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
+        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
+        db.Index("dataset_auto_disable_log_created_atx", "created_at"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
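
To show how the new models are meant to be traversed, a small sketch that walks from a document segment to its child chunks (the lookup helper is illustrative; the child_chunks property and the hierarchical/parent_mode checks come from the classes above):

    from extensions.ext_database import db
    from models.dataset import DocumentSegment

    def preview_child_chunks(segment_id: str) -> list[str]:
        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
        if not segment:
            return []
        # child_chunks is empty unless the document's process rule is
        # "hierarchical" with a paragraph-style parent_mode (not full-doc).
        return [child.content for child in segment.child_chunks]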

+ 35 - 1
api/schedule/clean_unused_datasets_task.py

@@ -10,7 +10,7 @@ from configs import dify_config
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
 from services.feature_service import FeatureService
 
 
@@ -75,6 +75,23 @@ def clean_unused_datasets_task():
             )
             if not dataset_query or len(dataset_query) == 0:
                 try:
+                    # add auto disable log
+                    documents = (
+                        db.session.query(Document)
+                        .filter(
+                            Document.dataset_id == dataset.id,
+                            Document.enabled == True,
+                            Document.archived == False,
+                        )
+                        .all()
+                    )
+                    for document in documents:
+                        dataset_auto_disable_log = DatasetAutoDisableLog(
+                            tenant_id=dataset.tenant_id,
+                            dataset_id=dataset.id,
+                            document_id=document.id,
+                        )
+                        db.session.add(dataset_auto_disable_log)
                     # remove index
                     index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
                     index_processor.clean(dataset, None)
@@ -151,6 +168,23 @@ def clean_unused_datasets_task():
                     else:
                         plan = plan_cache.decode()
                     if plan == "sandbox":
+                        # add auto disable log
+                        documents = (
+                            db.session.query(Document)
+                            .filter(
+                                Document.dataset_id == dataset.id,
+                                Document.enabled == True,
+                                Document.archived == False,
+                            )
+                            .all()
+                        )
+                        for document in documents:
+                            dataset_auto_disable_log = DatasetAutoDisableLog(
+                                tenant_id=dataset.tenant_id,
+                                dataset_id=dataset.id,
+                                document_id=document.id,
+                            )
+                            db.session.add(dataset_auto_disable_log)
                         # remove index
                         index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
                         index_processor.clean(dataset, None)
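
The "add auto disable log" block above is duplicated across the two branches of this task; a small helper along these lines (hypothetical name, same queries) would keep both call sites identical:

    from extensions.ext_database import db
    from models.dataset import Dataset, DatasetAutoDisableLog, Document

    def record_auto_disable_logs(dataset: Dataset) -> None:
        # One log row per still-enabled, non-archived document, so the notify
        # mail can report exactly which documents were disabled.
        documents = (
            db.session.query(Document)
            .filter(
                Document.dataset_id == dataset.id,
                Document.enabled == True,
                Document.archived == False,
            )
            .all()
        )
        for document in documents:
            db.session.add(
                DatasetAutoDisableLog(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=document.id,
                )
            )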

+ 66 - 0
api/schedule/mail_clean_document_notify_task.py

@@ -0,0 +1,66 @@
+import logging
+import time
+from collections import defaultdict
+
+import click
+from celery import shared_task
+from flask import render_template
+
+from extensions.ext_mail import mail
+from models.account import Account, Tenant, TenantAccountJoin
+from models.dataset import Dataset, DatasetAutoDisableLog
+
+
+@shared_task(queue="mail")
+def send_document_clean_notify_task():
+    """
+    Async Send document clean notify mail
+
+    Usage: send_document_clean_notify_task.delay()
+    """
+    if not mail.is_inited():
+        return
+
+    logging.info(click.style("Start send document clean notify mail", fg="green"))
+    start_at = time.perf_counter()
+
+    # send document clean notify mail
+    try:
+        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
+        # group by tenant_id so each workspace owner receives a single summary mail
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
+        for dataset_auto_disable_log in dataset_auto_disable_logs:
+            dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
+
+        for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
+            knowledge_details = []
+            tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
+            if not tenant:
+                continue
+            current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
+            account = Account.query.filter(Account.id == current_owner_join.account_id).first()
+            if not account:
+                continue
+
+            # group the tenant's disabled documents by dataset
+            dataset_auto_dataset_map: dict[str, list[str]] = defaultdict(list)
+            for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
+                dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
+                    dataset_auto_disable_log.document_id
+                )
+
+            for dataset_id, document_ids in dataset_auto_dataset_map.items():
+                dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
+                if dataset:
+                    document_count = len(document_ids)
+                    knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")
+
+            # render and send one summary mail per workspace owner;
+            # the knowledge_details context key is assumed to match the new template
+            html_content = render_template(
+                "clean_document_job_mail_template-US.html",
+                knowledge_details=knowledge_details,
+            )
+            mail.send(
+                to=account.email,
+                subject="Dify Knowledge base auto disable notification",
+                html=html_content,
+            )
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
+        )
+    except Exception:
+        logging.exception("Send document clean notify mail failed")
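
One step the task above still leaves open is flipping the notified flag once a tenant's mail has gone out; a sketch of that follow-up (hypothetical helper, to be called after a successful mail.send with that tenant's logs):

    from extensions.ext_database import db
    from models.dataset import DatasetAutoDisableLog

    def mark_logs_notified(logs: list[DatasetAutoDisableLog]) -> None:
        # Prevents the next scheduled run from emailing the same documents again.
        for log in logs:
            log.notified = True
            db.session.add(log)
        db.session.commit()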

+ 522 - 218
api/services/dataset_service.py

@@ -14,6 +14,7 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
@@ -23,7 +24,9 @@ from libs import helper
 from models.account import Account, TenantAccountRole
 from models.dataset import (
     AppDatasetJoin,
+    ChildChunk,
     Dataset,
+    DatasetAutoDisableLog,
     DatasetCollectionBinding,
     DatasetPermission,
     DatasetPermissionEnum,
@@ -35,8 +38,14 @@ from models.dataset import (
 )
 from models.model import UploadFile
 from models.source import DataSourceOauthBinding
-from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity
-from services.errors.account import NoPermissionError
+from services.entities.knowledge_entities.knowledge_entities import (
+    ChildChunkUpdateArgs,
+    KnowledgeConfig,
+    RetrievalModel,
+    SegmentUpdateArgs,
+)
+from services.errors.account import InvalidActionError, NoPermissionError
+from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
 from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
@@ -44,13 +53,16 @@ from services.external_knowledge_service import ExternalDatasetService
 from services.feature_service import FeatureModel, FeatureService
 from services.tag_service import TagService
 from services.vector_service import VectorService
+from tasks.batch_clean_document_task import batch_clean_document_task
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+from tasks.disable_segments_from_index_task import disable_segments_from_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
 from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
+from tasks.enable_segments_to_index_task import enable_segments_to_index_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
 from tasks.retry_document_indexing_task import retry_document_indexing_task
 from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -408,6 +420,24 @@ class DatasetService:
             .all()
         )
 
+    @staticmethod
+    def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
+        # get recent 30 days auto disable logs
+        start_date = datetime.datetime.now() - datetime.timedelta(days=30)
+        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
+            DatasetAutoDisableLog.dataset_id == dataset_id,
+            DatasetAutoDisableLog.created_at >= start_date,
+        ).all()
+        if dataset_auto_disable_logs:
+            return {
+                "document_ids": [log.document_id for log in dataset_auto_disable_logs],
+                "count": len(dataset_auto_disable_logs),
+            }
+        return {
+            "document_ids": [],
+            "count": 0,
+        }
+
 
 class DocumentService:
     DEFAULT_RULES = {
@@ -588,6 +618,20 @@ class DocumentService:
         db.session.delete(document)
         db.session.commit()
 
+    @staticmethod
+    def delete_documents(dataset: Dataset, document_ids: list[str]):
+        documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
+        file_ids = [
+            document.data_source_info_dict["upload_file_id"]
+            for document in documents
+            if document.data_source_type == "upload_file"
+        ]
+        batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
+
+        for document in documents:
+            db.session.delete(document)
+        db.session.commit()
+
     @staticmethod
     def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
         dataset = DatasetService.get_dataset(dataset_id)
@@ -689,7 +733,7 @@ class DocumentService:
     @staticmethod
     def save_document_with_dataset_id(
         dataset: Dataset,
-        document_data: dict,
+        knowledge_config: KnowledgeConfig,
         account: Account | Any,
         dataset_process_rule: Optional[DatasetProcessRule] = None,
         created_from: str = "web",
@@ -698,18 +742,18 @@ class DocumentService:
         features = FeatureService.get_features(current_user.current_tenant_id)
 
         if features.billing.enabled:
-            if "original_document_id" not in document_data or not document_data["original_document_id"]:
+            if not knowledge_config.original_document_id:
                 count = 0
-                if document_data["data_source"]["type"] == "upload_file":
-                    upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                     count = len(upload_file_list)
-                elif document_data["data_source"]["type"] == "notion_import":
-                    notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                     for notion_info in notion_info_list:
-                        count = count + len(notion_info["pages"])
-                elif document_data["data_source"]["type"] == "website_crawl":
-                    website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                    count = len(website_info["urls"])
+                        count = count + len(notion_info.pages)
+                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                    website_info = knowledge_config.data_source.info_list.website_info_list
+                    count = len(website_info.urls)
                 batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                 if count > batch_upload_limit:
                     raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -718,17 +762,14 @@ class DocumentService:
 
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = document_data["data_source"]["type"]
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
 
         if not dataset.indexing_technique:
-            if (
-                "indexing_technique" not in document_data
-                or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST
-            ):
-                raise ValueError("Indexing technique is required")
+            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+                raise ValueError("Indexing technique is invalid")
 
-            dataset.indexing_technique = document_data["indexing_technique"]
-            if document_data["indexing_technique"] == "high_quality":
+            dataset.indexing_technique = knowledge_config.indexing_technique
+            if knowledge_config.indexing_technique == "high_quality":
                 model_manager = ModelManager()
                 embedding_model = model_manager.get_default_model_instance(
                     tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
@@ -748,29 +789,29 @@ class DocumentService:
                         "score_threshold_enabled": False,
                     }
 
-                    dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model
+                    dataset.retrieval_model = (
+                        knowledge_config.retrieval_model.model_dump()
+                        if knowledge_config.retrieval_model
+                        else default_retrieval_model
+                    )
 
         documents = []
-        if document_data.get("original_document_id"):
-            document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
+        if knowledge_config.original_document_id:
+            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
             documents.append(document)
             batch = document.batch
         else:
             batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
             # save process rule
             if not dataset_process_rule:
-                process_rule = document_data["process_rule"]
-                if process_rule["mode"] == "custom":
+                process_rule = knowledge_config.process_rule
+                if process_rule.mode in ("custom", "hierarchical"):
                     dataset_process_rule = DatasetProcessRule(
                         dataset_id=dataset.id,
-                        mode=process_rule["mode"],
-                        rules=json.dumps(process_rule["rules"]),
+                        mode=process_rule.mode,
+                        rules=process_rule.rules.model_dump_json(),
                         created_by=account.id,
                     )
-                elif process_rule["mode"] == "automatic":
+                elif process_rule.mode == "automatic":
                     dataset_process_rule = DatasetProcessRule(
                         dataset_id=dataset.id,
-                        mode=process_rule["mode"],
+                        mode=process_rule.mode,
                         rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                         created_by=account.id,
                     )
@@ -786,8 +827,8 @@ class DocumentService:
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
-                if document_data["data_source"]["type"] == "upload_file":
-                    upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)
@@ -804,7 +845,7 @@ class DocumentService:
                             "upload_file_id": file_id,
                         }
                         # check duplicate
-                        if document_data.get("duplicate", False):
+                        if knowledge_config.duplicate:
                             document = Document.query.filter_by(
                                 dataset_id=dataset.id,
                                 tenant_id=current_user.current_tenant_id,
@@ -814,10 +855,10 @@ class DocumentService:
                             ).first()
                             if document:
                                 document.dataset_process_rule_id = dataset_process_rule.id
-                                document.updated_at = datetime.datetime.utcnow()
+                                document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
-                                document.doc_form = document_data["doc_form"]
-                                document.doc_language = document_data["doc_language"]
+                                document.doc_form = knowledge_config.doc_form
+                                document.doc_language = knowledge_config.doc_language
                                 document.data_source_info = json.dumps(data_source_info)
                                 document.batch = batch
                                 document.indexing_status = "waiting"
@@ -828,9 +869,9 @@ class DocumentService:
                         document = DocumentService.build_document(
                             dataset,
                             dataset_process_rule.id,
-                            document_data["data_source"]["type"],
-                            document_data["doc_form"],
-                            document_data["doc_language"],
+                            knowledge_config.data_source.info_list.data_source_type,
+                            knowledge_config.doc_form,
+                            knowledge_config.doc_language,
                             data_source_info,
                             created_from,
                             position,
@@ -843,8 +884,8 @@ class DocumentService:
                         document_ids.append(document.id)
                         documents.append(document)
                         position += 1
-                elif document_data["data_source"]["type"] == "notion_import":
-                    notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(
@@ -859,7 +900,7 @@ class DocumentService:
                             exist_page_ids.append(data_source_info["notion_page_id"])
                             exist_document[data_source_info["notion_page_id"]] = document.id
                     for notion_info in notion_info_list:
-                        workspace_id = notion_info["workspace_id"]
+                        workspace_id = notion_info.workspace_id
                         data_source_binding = DataSourceOauthBinding.query.filter(
                             db.and_(
                                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -870,25 +911,25 @@ class DocumentService:
                         ).first()
                         if not data_source_binding:
                             raise ValueError("Data source binding not found.")
-                        for page in notion_info["pages"]:
-                            if page["page_id"] not in exist_page_ids:
+                        for page in notion_info.pages:
+                            if page.page_id not in exist_page_ids:
                                 data_source_info = {
                                     "notion_workspace_id": workspace_id,
-                                    "notion_page_id": page["page_id"],
-                                    "notion_page_icon": page["page_icon"],
-                                    "type": page["type"],
+                                    "notion_page_id": page.page_id,
+                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+                                    "type": page.type,
                                 }
                                 document = DocumentService.build_document(
                                     dataset,
                                     dataset_process_rule.id,
-                                    document_data["data_source"]["type"],
-                                    document_data["doc_form"],
-                                    document_data["doc_language"],
+                                    knowledge_config.data_source.info_list.data_source_type,
+                                    knowledge_config.doc_form,
+                                    knowledge_config.doc_language,
                                     data_source_info,
                                     created_from,
                                     position,
                                     account,
-                                    page["page_name"],
+                                    page.page_name,
                                     batch,
                                 )
                                 db.session.add(document)
@@ -897,19 +938,19 @@ class DocumentService:
                                 documents.append(document)
                                 position += 1
                             else:
-                                exist_document.pop(page["page_id"])
+                                exist_document.pop(page.page_id)
                     # delete not selected documents
                     if len(exist_document) > 0:
                         clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-                elif document_data["data_source"]["type"] == "website_crawl":
-                    website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                    urls = website_info["urls"]
+                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                    website_info = knowledge_config.data_source.info_list.website_info_list
+                    urls = website_info.urls
                     for url in urls:
                         data_source_info = {
                             "url": url,
-                            "provider": website_info["provider"],
-                            "job_id": website_info["job_id"],
-                            "only_main_content": website_info.get("only_main_content", False),
+                            "provider": website_info.provider,
+                            "job_id": website_info.job_id,
+                            "only_main_content": website_info.only_main_content,
                             "mode": "crawl",
                         }
                         if len(url) > 255:
@@ -919,9 +960,9 @@ class DocumentService:
                         document = DocumentService.build_document(
                             dataset,
                             dataset_process_rule.id,
-                            document_data["data_source"]["type"],
-                            document_data["doc_form"],
-                            document_data["doc_language"],
+                            knowledge_config.data_source.info_list.data_source_type,
+                            knowledge_config.doc_form,
+                            knowledge_config.doc_language,
                             data_source_info,
                             created_from,
                             position,
@@ -995,31 +1036,31 @@ class DocumentService:
     @staticmethod
     def update_document_with_dataset_id(
         dataset: Dataset,
-        document_data: dict,
+        document_data: KnowledgeConfig,
         account: Account,
         dataset_process_rule: Optional[DatasetProcessRule] = None,
         created_from: str = "web",
     ):
         DatasetService.check_dataset_model_setting(dataset)
-        document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
+        document = DocumentService.get_document(dataset.id, document_data.original_document_id)
         if document is None:
             raise NotFound("Document not found")
         if document.display_status != "available":
             raise ValueError("Document is not available")
         # save process rule
-        if document_data.get("process_rule"):
-            process_rule = document_data["process_rule"]
-            if process_rule["mode"] == "custom":
+        if document_data.process_rule:
+            process_rule = document_data.process_rule
+            if process_rule.mode in {"custom", "hierarchical"}:
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
-                    mode=process_rule["mode"],
-                    rules=json.dumps(process_rule["rules"]),
+                    mode=process_rule.mode,
+                    rules=process_rule.rules.model_dump_json(),
                     created_by=account.id,
                 )
-            elif process_rule["mode"] == "automatic":
+            elif process_rule.mode == "automatic":
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
-                    mode=process_rule["mode"],
+                    mode=process_rule.mode,
                     rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                     created_by=account.id,
                 )
@@ -1028,11 +1069,11 @@ class DocumentService:
                 db.session.commit()
                 document.dataset_process_rule_id = dataset_process_rule.id
         # update document data source
-        if document_data.get("data_source"):
+        if document_data.data_source:
             file_name = ""
             data_source_info = {}
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+            if document_data.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)
@@ -1048,10 +1089,10 @@ class DocumentService:
                     data_source_info = {
                         "upload_file_id": file_id,
                     }
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+            elif document_data.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = document_data.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
-                    workspace_id = notion_info["workspace_id"]
+                    workspace_id = notion_info.workspace_id
                     data_source_binding = DataSourceOauthBinding.query.filter(
                         db.and_(
                             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -1062,31 +1103,31 @@ class DocumentService:
                     ).first()
                     if not data_source_binding:
                         raise ValueError("Data source binding not found.")
-                    for page in notion_info["pages"]:
+                    for page in notion_info.pages:
                         data_source_info = {
                             "notion_workspace_id": workspace_id,
-                            "notion_page_id": page["page_id"],
-                            "notion_page_icon": page["page_icon"],
-                            "type": page["type"],
+                            "notion_page_id": page.page_id,
+                            "notion_page_icon": page.page_icon,
+                            "type": page.type,
                         }
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                urls = website_info["urls"]
+            elif document_data.data_source.info_list.data_source_type == "website_crawl":
+                website_info = document_data.data_source.info_list.website_info_list
+                urls = website_info.urls
                 for url in urls:
                     data_source_info = {
                         "url": url,
-                        "provider": website_info["provider"],
-                        "job_id": website_info["job_id"],
-                        "only_main_content": website_info.get("only_main_content", False),
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,
                         "mode": "crawl",
                     }
-            document.data_source_type = document_data["data_source"]["type"]
+            document.data_source_type = document_data.data_source.info_list.data_source_type
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
 
         # update document name
-        if document_data.get("name"):
-            document.name = document_data["name"]
+        if document_data.name:
+            document.name = document_data.name
         # update document to be waiting
         document.indexing_status = "waiting"
         document.completed_at = None
@@ -1096,7 +1137,7 @@ class DocumentService:
         document.splitting_completed_at = None
         document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         document.created_from = created_from
-        document.doc_form = document_data["doc_form"]
+        document.doc_form = document_data.doc_form
         db.session.add(document)
         db.session.commit()
         # update document segment
@@ -1108,21 +1149,21 @@ class DocumentService:
         return document
 
     @staticmethod
-    def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
+    def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
         features = FeatureService.get_features(current_user.current_tenant_id)
 
         if features.billing.enabled:
             count = 0
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 count = len(upload_file_list)
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
-                    count = count + len(notion_info["pages"])
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                count = len(website_info["urls"])
+                    count = count + len(notion_info.pages)
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
+                count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1131,13 +1172,13 @@ class DocumentService:
 
         dataset_collection_binding_id = None
         retrieval_model = None
-        if document_data["indexing_technique"] == "high_quality":
+        if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                document_data["embedding_model_provider"], document_data["embedding_model"]
+                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
             )
             dataset_collection_binding_id = dataset_collection_binding.id
-            if document_data.get("retrieval_model"):
-                retrieval_model = document_data["retrieval_model"]
+            if knowledge_config.retrieval_model:
+                retrieval_model = knowledge_config.retrieval_model
             else:
                 default_retrieval_model = {
                     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -1146,24 +1187,24 @@ class DocumentService:
                     "top_k": 2,
                     "score_threshold_enabled": False,
                 }
-                retrieval_model = default_retrieval_model
+                retrieval_model = RetrievalModel(**default_retrieval_model)
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
             name="",
-            data_source_type=document_data["data_source"]["type"],
-            indexing_technique=document_data.get("indexing_technique", "high_quality"),
+            data_source_type=knowledge_config.data_source.info_list.data_source_type,
+            indexing_technique=knowledge_config.indexing_technique,
             created_by=account.id,
-            embedding_model=document_data.get("embedding_model"),
-            embedding_model_provider=document_data.get("embedding_model_provider"),
+            embedding_model=knowledge_config.embedding_model,
+            embedding_model_provider=knowledge_config.embedding_model_provider,
             collection_binding_id=dataset_collection_binding_id,
-            retrieval_model=retrieval_model,
+            retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
         )
 
-        db.session.add(dataset)
+        db.session.add(dataset)  # type: ignore
         db.session.flush()
 
-        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)
+        documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
 
         cut_length = 18
         cut_name = documents[0].name[:cut_length]
@@ -1174,133 +1215,86 @@ class DocumentService:
         return dataset, documents, batch
 
     @classmethod
-    def document_create_args_validate(cls, args: dict):
-        if "original_document_id" not in args or not args["original_document_id"]:
-            DocumentService.data_source_args_validate(args)
-            DocumentService.process_rule_args_validate(args)
+    def document_create_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.data_source and not knowledge_config.process_rule:
+            raise ValueError("Data source or Process rule is required")
         else:
-            if ("data_source" not in args or not args["data_source"]) and (
-                "process_rule" not in args or not args["process_rule"]
-            ):
-                raise ValueError("Data source or Process rule is required")
-            else:
-                if args.get("data_source"):
-                    DocumentService.data_source_args_validate(args)
-                if args.get("process_rule"):
-                    DocumentService.process_rule_args_validate(args)
+            if knowledge_config.data_source:
+                DocumentService.data_source_args_validate(knowledge_config)
+            if knowledge_config.process_rule:
+                DocumentService.process_rule_args_validate(knowledge_config)
 
     @classmethod
-    def data_source_args_validate(cls, args: dict):
-        if "data_source" not in args or not args["data_source"]:
+    def data_source_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.data_source:
             raise ValueError("Data source is required")
 
-        if not isinstance(args["data_source"], dict):
-            raise ValueError("Data source is invalid")
-
-        if "type" not in args["data_source"] or not args["data_source"]["type"]:
-            raise ValueError("Data source type is required")
-
-        if args["data_source"]["type"] not in Document.DATA_SOURCES:
+        if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES:
             raise ValueError("Data source type is invalid")
 
-        if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]:
+        if not knowledge_config.data_source.info_list:
             raise ValueError("Data source info is required")
 
-        if args["data_source"]["type"] == "upload_file":
-            if (
-                "file_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["file_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+            if not knowledge_config.data_source.info_list.file_info_list:
                 raise ValueError("File source info is required")
-        if args["data_source"]["type"] == "notion_import":
-            if (
-                "notion_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["notion_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "notion_import":
+            if not knowledge_config.data_source.info_list.notion_info_list:
                 raise ValueError("Notion source info is required")
-        if args["data_source"]["type"] == "website_crawl":
-            if (
-                "website_info_list" not in args["data_source"]["info_list"]
-                or not args["data_source"]["info_list"]["website_info_list"]
-            ):
+        if knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+            if not knowledge_config.data_source.info_list.website_info_list:
                 raise ValueError("Website source info is required")
 
     @classmethod
-    def process_rule_args_validate(cls, args: dict):
-        if "process_rule" not in args or not args["process_rule"]:
+    def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig):
+        if not knowledge_config.process_rule:
             raise ValueError("Process rule is required")
 
-        if not isinstance(args["process_rule"], dict):
-            raise ValueError("Process rule is invalid")
-
-        if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
+        if not knowledge_config.process_rule.mode:
             raise ValueError("Process rule mode is required")
 
-        if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
+        if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES:
             raise ValueError("Process rule mode is invalid")
 
-        if args["process_rule"]["mode"] == "automatic":
-            args["process_rule"]["rules"] = {}
+        if knowledge_config.process_rule.mode == "automatic":
+            knowledge_config.process_rule.rules = None
         else:
-            if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
+            if not knowledge_config.process_rule.rules:
                 raise ValueError("Process rule rules is required")
 
-            if not isinstance(args["process_rule"]["rules"], dict):
-                raise ValueError("Process rule rules is invalid")
-
-            if (
-                "pre_processing_rules" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["pre_processing_rules"] is None
-            ):
+            if knowledge_config.process_rule.rules.pre_processing_rules is None:
                 raise ValueError("Process rule pre_processing_rules is required")
 
-            if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
-                raise ValueError("Process rule pre_processing_rules is invalid")
-
             unique_pre_processing_rule_dicts = {}
-            for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
-                if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
+            for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules:
+                if not pre_processing_rule.id:
                     raise ValueError("Process rule pre_processing_rules id is required")
 
-                if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
-                    raise ValueError("Process rule pre_processing_rules id is invalid")
-
-                if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
-                    raise ValueError("Process rule pre_processing_rules enabled is required")
-
-                if not isinstance(pre_processing_rule["enabled"], bool):
+                if not isinstance(pre_processing_rule.enabled, bool):
                     raise ValueError("Process rule pre_processing_rules enabled is invalid")
 
-                unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule
+                unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule
 
-            args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())
+            knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
 
-            if (
-                "segmentation" not in args["process_rule"]["rules"]
-                or args["process_rule"]["rules"]["segmentation"] is None
-            ):
+            if not knowledge_config.process_rule.rules.segmentation:
                 raise ValueError("Process rule segmentation is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
-                raise ValueError("Process rule segmentation is invalid")
-
-            if (
-                "separator" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["separator"]
-            ):
+            if not knowledge_config.process_rule.rules.segmentation.separator:
                 raise ValueError("Process rule segmentation separator is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
+            if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str):
                 raise ValueError("Process rule segmentation separator is invalid")
 
-            if (
-                "max_tokens" not in args["process_rule"]["rules"]["segmentation"]
-                or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
+            if not (
+                knowledge_config.process_rule.mode == "hierarchical"
+                and knowledge_config.process_rule.rules.parent_mode == "full-doc"
             ):
-                raise ValueError("Process rule segmentation max_tokens is required")
+                if not knowledge_config.process_rule.rules.segmentation.max_tokens:
+                    raise ValueError("Process rule segmentation max_tokens is required")
 
-            if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
-                raise ValueError("Process rule segmentation max_tokens is invalid")
+                if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int):
+                    raise ValueError("Process rule segmentation max_tokens is invalid")
 
     @classmethod
     def estimate_args_validate(cls, args: dict):
@@ -1447,7 +1441,7 @@ class SegmentService:
 
             # save vector index
             try:
-                VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset)
+                VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
             except Exception as e:
                 logging.exception("create segment index failed")
                 segment_document.enabled = False
@@ -1525,7 +1519,7 @@ class SegmentService:
             db.session.add(document)
             try:
                 # save vector index
-                VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
+                VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
             except Exception as e:
                 logging.exception("create segment index failed")
                 for segment_document in segment_data_list:
@@ -1537,14 +1531,13 @@ class SegmentService:
             return segment_data_list
 
     @classmethod
-    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
-        segment_update_entity = SegmentUpdateEntity(**args)
+    def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is indexing, please try again later")
-        if segment_update_entity.enabled is not None:
-            action = segment_update_entity.enabled
+        if args.enabled is not None:
+            action = args.enabled
             if segment.enabled != action:
                 if not action:
                     segment.enabled = action
@@ -1557,22 +1550,22 @@ class SegmentService:
                     disable_segment_from_index_task.delay(segment.id)
                     return segment
         if not segment.enabled:
-            if segment_update_entity.enabled is not None:
-                if not segment_update_entity.enabled:
+            if args.enabled is not None:
+                if not args.enabled:
                     raise ValueError("Can't update disabled segment")
             else:
                 raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = segment_update_entity.content
+            content = args.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
-                    segment.answer = segment_update_entity.answer
-                    segment.word_count += len(segment_update_entity.answer or "")
+                    segment.answer = args.answer
+                    segment.word_count += len(args.answer or "")
                 word_count_change = segment.word_count - word_count_change
-                if segment_update_entity.keywords:
-                    segment.keywords = segment_update_entity.keywords
+                if args.keywords:
+                    segment.keywords = args.keywords
                 segment.enabled = True
                 segment.disabled_at = None
                 segment.disabled_by = None
@@ -1583,9 +1576,38 @@ class SegmentService:
                     document.word_count = max(0, document.word_count + word_count_change)
                     db.session.add(document)
                 # update segment index task
-                if segment_update_entity.enabled:
-                    keywords = segment_update_entity.keywords or []
-                    VectorService.create_segments_vector([keywords], [segment], dataset)
+                if args.enabled:
+                    VectorService.create_segments_vector(
+                        [args.keywords or []], [segment], dataset, document.doc_form
+                    )
+                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+                    # regenerate child chunks
+                    # get embedding model instance
+                    if dataset.indexing_technique == "high_quality":
+                        # check embedding model setting
+                        model_manager = ModelManager()
+
+                        if dataset.embedding_model_provider:
+                            embedding_model_instance = model_manager.get_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model,
+                            )
+                        else:
+                            embedding_model_instance = model_manager.get_default_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                            )
+                    else:
+                        raise ValueError("The knowledge base index technique is not high quality!")
+                    # get the process rule
+                    processing_rule = (
+                        db.session.query(DatasetProcessRule)
+                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                        .first()
+                    )
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0
@@ -1616,8 +1638,8 @@ class SegmentService:
                 segment.disabled_at = None
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
-                    segment.answer = segment_update_entity.answer
-                    segment.word_count += len(segment_update_entity.answer or "")
+                    segment.answer = args.answer
+                    segment.word_count += len(args.answer or "")
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
@@ -1625,8 +1647,38 @@ class SegmentService:
                     db.session.add(document)
                 db.session.add(segment)
                 db.session.commit()
-                # update segment vector index
-                VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset)
+                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+                    # get embedding model instance
+                    if dataset.indexing_technique == "high_quality":
+                        # check embedding model setting
+                        model_manager = ModelManager()
+
+                        if dataset.embedding_model_provider:
+                            embedding_model_instance = model_manager.get_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                provider=dataset.embedding_model_provider,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                                model=dataset.embedding_model,
+                            )
+                        else:
+                            embedding_model_instance = model_manager.get_default_model_instance(
+                                tenant_id=dataset.tenant_id,
+                                model_type=ModelType.TEXT_EMBEDDING,
+                            )
+                    else:
+                        raise ValueError("The knowledge base index technique is not high quality!")
+                    # get the process rule
+                    processing_rule = (
+                        db.session.query(DatasetProcessRule)
+                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                        .first()
+                    )
+                    VectorService.generate_child_chunks(
+                        segment, document, dataset, embedding_model_instance, processing_rule, True
+                    )
+                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+                    # update segment vector index
+                    VectorService.update_segment_vector(args.keywords, segment, dataset)
 
         except Exception as e:
             logging.exception("update segment index failed")
@@ -1649,13 +1701,265 @@ class SegmentService:
         if segment.enabled:
             # send delete segment index task
             redis_client.setex(indexing_cache_key, 600, 1)
-            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
+            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
         db.session.delete(segment)
         # update document word count
         document.word_count -= segment.word_count
         db.session.add(document)
         db.session.commit()
 
+    @classmethod
+    def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
+        index_node_ids = (
+            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            .filter(
+                DocumentSegment.id.in_(segment_ids),
+                DocumentSegment.dataset_id == dataset.id,
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.tenant_id == current_user.current_tenant_id,
+            )
+            .all()
+        )
+        index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
+
+        delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
+        db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
+        db.session.commit()
+
+    @classmethod
+    def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
+        if action == "enable":
+            segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.id.in_(segment_ids),
+                    DocumentSegment.dataset_id == dataset.id,
+                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.enabled == False,
+                )
+                .all()
+            )
+            if not segments:
+                return
+            real_deal_segment_ids = []
+            for segment in segments:
+                indexing_cache_key = "segment_{}_indexing".format(segment.id)
+                cache_result = redis_client.get(indexing_cache_key)
+                if cache_result is not None:
+                    continue
+                segment.enabled = True
+                segment.disabled_at = None
+                segment.disabled_by = None
+                db.session.add(segment)
+                real_deal_segment_ids.append(segment.id)
+            db.session.commit()
+
+            enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
+        elif action == "disable":
+            segments = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.id.in_(segment_ids),
+                    DocumentSegment.dataset_id == dataset.id,
+                    DocumentSegment.document_id == document.id,
+                    DocumentSegment.enabled == True,
+                )
+                .all()
+            )
+            if not segments:
+                return
+            real_deal_segment_ids = []
+            for segment in segments:
+                indexing_cache_key = "segment_{}_indexing".format(segment.id)
+                cache_result = redis_client.get(indexing_cache_key)
+                if cache_result is not None:
+                    continue
+                segment.enabled = False
+                segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                segment.disabled_by = current_user.id
+                db.session.add(segment)
+                real_deal_segment_ids.append(segment.id)
+            db.session.commit()
+
+            disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
+        else:
+            raise InvalidActionError()
+
+    @classmethod
+    def create_child_chunk(
+        cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
+    ) -> ChildChunk:
+        lock_name = "add_child_lock_{}".format(segment.id)
+        with redis_client.lock(lock_name, timeout=20):
+            index_node_id = str(uuid.uuid4())
+            index_node_hash = helper.generate_text_hash(content)
+            child_chunk_count = (
+                db.session.query(ChildChunk)
+                .filter(
+                    ChildChunk.tenant_id == current_user.current_tenant_id,
+                    ChildChunk.dataset_id == dataset.id,
+                    ChildChunk.document_id == document.id,
+                    ChildChunk.segment_id == segment.id,
+                )
+                .count()
+            )
+            max_position = (
+                db.session.query(func.max(ChildChunk.position))
+                .filter(
+                    ChildChunk.tenant_id == current_user.current_tenant_id,
+                    ChildChunk.dataset_id == dataset.id,
+                    ChildChunk.document_id == document.id,
+                    ChildChunk.segment_id == segment.id,
+                )
+                .scalar()
+            )
+            child_chunk = ChildChunk(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset.id,
+                document_id=document.id,
+                segment_id=segment.id,
+                position=max_position + 1 if max_position else 1,
+                index_node_id=index_node_id,
+                index_node_hash=index_node_hash,
+                content=content,
+                word_count=len(content),
+                type="customized",
+                created_by=current_user.id,
+            )
+            db.session.add(child_chunk)
+            # save vector index
+            try:
+                VectorService.create_child_chunk_vector(child_chunk, dataset)
+            except Exception as e:
+                logging.exception("create child chunk index failed")
+                db.session.rollback()
+                raise ChildChunkIndexingError(str(e))
+            db.session.commit()
+
+            return child_chunk
+
+    @classmethod
+    def update_child_chunks(
+        cls,
+        child_chunks_update_args: list[ChildChunkUpdateArgs],
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+    ) -> list[ChildChunk]:
+        child_chunks = (
+            db.session.query(ChildChunk)
+            .filter(
+                ChildChunk.dataset_id == dataset.id,
+                ChildChunk.document_id == document.id,
+                ChildChunk.segment_id == segment.id,
+            )
+            .all()
+        )
+        child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
+
+        new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
+
+        for child_chunk_update_args in child_chunks_update_args:
+            if child_chunk_update_args.id:
+                child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None)
+                if child_chunk:
+                    if child_chunk.content != child_chunk_update_args.content:
+                        child_chunk.content = child_chunk_update_args.content
+                        child_chunk.word_count = len(child_chunk.content)
+                        child_chunk.updated_by = current_user.id
+                        child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                        child_chunk.type = "customized"
+                        update_child_chunks.append(child_chunk)
+            else:
+                new_child_chunks_args.append(child_chunk_update_args)
+        if child_chunks_map:
+            delete_child_chunks = list(child_chunks_map.values())
+        try:
+            if update_child_chunks:
+                db.session.bulk_save_objects(update_child_chunks)
+
+            if delete_child_chunks:
+                for child_chunk in delete_child_chunks:
+                    db.session.delete(child_chunk)
+            if new_child_chunks_args:
+                child_chunk_count = len(child_chunks)
+                for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1):
+                    index_node_id = str(uuid.uuid4())
+                    index_node_hash = helper.generate_text_hash(args.content)
+                    child_chunk = ChildChunk(
+                        tenant_id=current_user.current_tenant_id,
+                        dataset_id=dataset.id,
+                        document_id=document.id,
+                        segment_id=segment.id,
+                        position=position,
+                        index_node_id=index_node_id,
+                        index_node_hash=index_node_hash,
+                        content=args.content,
+                        word_count=len(args.content),
+                        type="customized",
+                        created_by=current_user.id,
+                    )
+
+                    db.session.add(child_chunk)
+                    db.session.flush()
+                    new_child_chunks.append(child_chunk)
+            VectorService.update_child_chunk_vector(
+                new_child_chunks, update_child_chunks, delete_child_chunks, dataset
+            )
+            db.session.commit()
+        except Exception as e:
+            logging.exception("update child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkIndexingError(str(e))
+        return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position)
+
+    @classmethod
+    def update_child_chunk(
+        cls,
+        content: str,
+        child_chunk: ChildChunk,
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+    ) -> ChildChunk:
+        try:
+            child_chunk.content = content
+            child_chunk.word_count = len(content)
+            child_chunk.updated_by = current_user.id
+            child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            child_chunk.type = "customized"
+            db.session.add(child_chunk)
+            VectorService.update_child_chunk_vector([], [child_chunk], [], dataset)
+            db.session.commit()
+        except Exception as e:
+            logging.exception("update child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkIndexingError(str(e))
+        return child_chunk
+
+    @classmethod
+    def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset):
+        db.session.delete(child_chunk)
+        try:
+            VectorService.delete_child_chunk_vector(child_chunk, dataset)
+        except Exception as e:
+            logging.exception("delete child chunk index failed")
+            db.session.rollback()
+            raise ChildChunkDeleteIndexError(str(e))
+        db.session.commit()
+
+    @classmethod
+    def get_child_chunks(
+        cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
+    ):
+        query = ChildChunk.query.filter_by(
+            tenant_id=current_user.current_tenant_id,
+            dataset_id=dataset_id,
+            document_id=document_id,
+            segment_id=segment_id,
+        ).order_by(ChildChunk.position.asc())
+        if keyword:
+            query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
+        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+
 
 class DatasetCollectionBindingService:
     @classmethod

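A minimal caller sketch for the SegmentService additions above (hypothetical, not part of this change set; `dataset`, `document`, `segment`, `segment_ids`, and `existing_chunk_id` are assumed to come from the surrounding controller, which runs in the usual request context with current_user available):

    from services.dataset_service import SegmentService
    from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs

    # Disable several segments at once; ids that do not match or are already disabled are skipped,
    # and the index cleanup runs asynchronously via disable_segments_from_index_task.
    SegmentService.update_segments_status(segment_ids, "disable", dataset, document)

    # Replace a segment's child chunks in one call: entries with an id update existing chunks,
    # entries without an id are appended, and chunks missing from the list are deleted.
    SegmentService.update_child_chunks(
        [
            ChildChunkUpdateArgs(id=existing_chunk_id, content="updated child text"),
            ChildChunkUpdateArgs(content="a brand-new child chunk"),
        ],
        segment,
        document,
        dataset,
    )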
+ 111 - 1
api/services/entities/knowledge_entities/knowledge_entities.py

@@ -1,4 +1,5 @@
-from typing import Optional
+from enum import Enum
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
@@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel):
     answer: Optional[str] = None
     keywords: Optional[list[str]] = None
     enabled: Optional[bool] = None
+
+
+class ParentMode(str, Enum):
+    FULL_DOC = "full-doc"
+    PARAGRAPH = "paragraph"
+
+
+class NotionIcon(BaseModel):
+    type: str
+    url: Optional[str] = None
+    emoji: Optional[str] = None
+
+
+class NotionPage(BaseModel):
+    page_id: str
+    page_name: str
+    page_icon: Optional[NotionIcon] = None
+    type: str
+
+
+class NotionInfo(BaseModel):
+    workspace_id: str
+    pages: list[NotionPage]
+
+
+class WebsiteInfo(BaseModel):
+    provider: str
+    job_id: str
+    urls: list[str]
+    only_main_content: bool = True
+
+
+class FileInfo(BaseModel):
+    file_ids: list[str]
+
+
+class InfoList(BaseModel):
+    data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
+    notion_info_list: Optional[list[NotionInfo]] = None
+    file_info_list: Optional[FileInfo] = None
+    website_info_list: Optional[WebsiteInfo] = None
+
+
+class DataSource(BaseModel):
+    info_list: InfoList
+
+
+class PreProcessingRule(BaseModel):
+    id: str
+    enabled: bool
+
+
+class Segmentation(BaseModel):
+    separator: str = "\n"
+    max_tokens: int
+    chunk_overlap: int = 0
+
+
+class Rule(BaseModel):
+    pre_processing_rules: Optional[list[PreProcessingRule]] = None
+    segmentation: Optional[Segmentation] = None
+    parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
+    subchunk_segmentation: Optional[Segmentation] = None
+
+
+class ProcessRule(BaseModel):
+    mode: Literal["automatic", "custom", "hierarchical"]
+    rules: Optional[Rule] = None
+
+
+class RerankingModel(BaseModel):
+    reranking_provider_name: Optional[str] = None
+    reranking_model_name: Optional[str] = None
+
+
+class RetrievalModel(BaseModel):
+    search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
+    reranking_enable: bool
+    reranking_model: Optional[RerankingModel] = None
+    top_k: int
+    score_threshold_enabled: bool
+    score_threshold: Optional[float] = None
+
+
+class KnowledgeConfig(BaseModel):
+    original_document_id: Optional[str] = None
+    duplicate: bool = True
+    indexing_technique: Literal["high_quality", "economy"]
+    data_source: Optional[DataSource] = None
+    process_rule: Optional[ProcessRule] = None
+    retrieval_model: Optional[RetrievalModel] = None
+    doc_form: str = "text_model"
+    doc_language: str = "English"
+    embedding_model: Optional[str] = None
+    embedding_model_provider: Optional[str] = None
+    name: Optional[str] = None
+
+
+class SegmentUpdateArgs(BaseModel):
+    content: Optional[str] = None
+    answer: Optional[str] = None
+    keywords: Optional[list[str]] = None
+    regenerate_child_chunks: bool = False
+    enabled: Optional[bool] = None
+
+
+class ChildChunkUpdateArgs(BaseModel):
+    id: Optional[str] = None
+    content: str

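A quick sketch of how the new entities fit together (illustrative only; the field values below are examples, not defaults from this change set):

    from services.entities.knowledge_entities.knowledge_entities import (
        ProcessRule,
        Rule,
        Segmentation,
        SegmentUpdateArgs,
    )

    # Hierarchical (parent-child) process rule: parent segments are split by `segmentation`,
    # child chunks by `subchunk_segmentation`, and parent_mode picks paragraph or full-doc parents.
    process_rule = ProcessRule(
        mode="hierarchical",
        rules=Rule(
            segmentation=Segmentation(separator="\n\n", max_tokens=1024),
            parent_mode="paragraph",
            subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
        ),
    )

    # Segment update that asks the service layer to re-split the parent into child chunks.
    update_args = SegmentUpdateArgs(content="edited parent content", regenerate_child_chunks=True)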
+ 9 - 0
api/services/errors/chunk.py

@@ -0,0 +1,9 @@
+from services.errors.base import BaseServiceError
+
+
+class ChildChunkIndexingError(BaseServiceError):
+    description = "{message}"
+
+
+class ChildChunkDeleteIndexError(BaseServiceError):
+    description = "{message}"

+ 5 - 32
api/services/hit_testing_service.py

@@ -7,7 +7,7 @@ from core.rag.models.document import Document
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from models.account import Account
-from models.dataset import Dataset, DatasetQuery, DocumentSegment
+from models.dataset import Dataset, DatasetQuery
 
 default_retrieval_model = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -69,7 +69,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return dict(cls.compact_retrieve_response(dataset, query, all_documents))
+        return cls.compact_retrieve_response(query, all_documents)
 
     @classmethod
     def external_retrieve(
@@ -106,41 +106,14 @@ class HitTestingService:
         return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
 
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
-        records = []
-
-        for document in documents:
-            if document.metadata is None:
-                continue
-
-            index_node_id = document.metadata["doc_id"]
-
-            segment = (
-                db.session.query(DocumentSegment)
-                .filter(
-                    DocumentSegment.dataset_id == dataset.id,
-                    DocumentSegment.enabled == True,
-                    DocumentSegment.status == "completed",
-                    DocumentSegment.index_node_id == index_node_id,
-                )
-                .first()
-            )
-
-            if not segment:
-                continue
-
-            record = {
-                "segment": segment,
-                "score": document.metadata.get("score", None),
-            }
-
-            records.append(record)
+    def compact_retrieve_response(cls, query: str, documents: list[Document]):
+        records = RetrievalService.format_retrieval_documents(documents)
 
         return {
             "query": {
                 "content": query,
             },
-            "records": records,
+            "records": [record.model_dump() for record in records],
         }
 
     @classmethod

+ 171 - 23
api/services/vector_service.py

@@ -1,40 +1,68 @@
 from typing import Optional
 
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
-from models.dataset import Dataset, DocumentSegment
+from extensions.ext_database import db
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import Document as DatasetDocument
+from services.entities.knowledge_entities.knowledge_entities import ParentMode
 
 
 class VectorService:
     @classmethod
     def create_segments_vector(
-        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
+        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
     ):
         documents = []
-        for segment in segments:
-            document = Document(
-                page_content=segment.content,
-                metadata={
-                    "doc_id": segment.index_node_id,
-                    "doc_hash": segment.index_node_hash,
-                    "document_id": segment.document_id,
-                    "dataset_id": segment.dataset_id,
-                },
-            )
-            documents.append(document)
-        if dataset.indexing_technique == "high_quality":
-            # save vector index
-            vector = Vector(dataset=dataset)
-            vector.add_texts(documents, duplicate_check=True)
 
-        # save keyword index
-        keyword = Keyword(dataset)
+        for segment in segments:
+            if doc_form == IndexType.PARENT_CHILD_INDEX:
+                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
+                # get the process rule
+                processing_rule = (
+                    db.session.query(DatasetProcessRule)
+                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                    .first()
+                )
+                # get embedding model instance
+                if dataset.indexing_technique == "high_quality":
+                    # check embedding model setting
+                    model_manager = ModelManager()
 
-        if keywords_list and len(keywords_list) > 0:
-            keyword.add_texts(documents, keywords_list=keywords_list)
-        else:
-            keyword.add_texts(documents)
+                    if dataset.embedding_model_provider:
+                        embedding_model_instance = model_manager.get_model_instance(
+                            tenant_id=dataset.tenant_id,
+                            provider=dataset.embedding_model_provider,
+                            model_type=ModelType.TEXT_EMBEDDING,
+                            model=dataset.embedding_model,
+                        )
+                    else:
+                        embedding_model_instance = model_manager.get_default_model_instance(
+                            tenant_id=dataset.tenant_id,
+                            model_type=ModelType.TEXT_EMBEDDING,
+                        )
+                else:
+                    raise ValueError("The knowledge base index technique is not high quality!")
+                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
+            else:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    },
+                )
+                documents.append(document)
+        if len(documents) > 0:
+            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
 
     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
@@ -65,3 +93,123 @@ class VectorService:
             keyword.add_texts([document], keywords_list=[keywords])
         else:
             keyword.add_texts([document])
+
+    @classmethod
+    def generate_child_chunks(
+        cls,
+        segment: DocumentSegment,
+        dataset_document: DatasetDocument,
+        dataset: Dataset,
+        embedding_model_instance: ModelInstance,
+        processing_rule: DatasetProcessRule,
+        regenerate: bool = False,
+    ):
+        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
+        if regenerate:
+            # delete child chunks
+            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
+
+        # generate child chunks
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            },
+        )
+        # use full doc mode to generate segment's child chunk
+        processing_rule_dict = processing_rule.to_dict()
+        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
+        documents = index_processor.transform(
+            [document],
+            embedding_model_instance=embedding_model_instance,
+            process_rule=processing_rule_dict,
+            tenant_id=dataset.tenant_id,
+            doc_language=dataset_document.doc_language,
+        )
+        # save child chunks
+        if len(documents) > 0 and len(documents[0].children) > 0:
+            index_processor.load(dataset, documents)
+
+            for position, child_chunk in enumerate(documents[0].children, start=1):
+                child_segment = ChildChunk(
+                    tenant_id=dataset.tenant_id,
+                    dataset_id=dataset.id,
+                    document_id=dataset_document.id,
+                    segment_id=segment.id,
+                    position=position,
+                    index_node_id=child_chunk.metadata["doc_id"],
+                    index_node_hash=child_chunk.metadata["doc_hash"],
+                    content=child_chunk.page_content,
+                    word_count=len(child_chunk.page_content),
+                    type="automatic",
+                    created_by=dataset_document.created_by,
+                )
+                db.session.add(child_segment)
+        db.session.commit()
+
+    @classmethod
+    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
+        child_document = Document(
+            page_content=child_segment.content,
+            metadata={
+                "doc_id": child_segment.index_node_id,
+                "doc_hash": child_segment.index_node_hash,
+                "document_id": child_segment.document_id,
+                "dataset_id": child_segment.dataset_id,
+            },
+        )
+        if dataset.indexing_technique == "high_quality":
+            # save vector index
+            vector = Vector(dataset=dataset)
+            vector.add_texts([child_document], duplicate_check=True)
+
+    @classmethod
+    def update_child_chunk_vector(
+        cls,
+        new_child_chunks: list[ChildChunk],
+        update_child_chunks: list[ChildChunk],
+        delete_child_chunks: list[ChildChunk],
+        dataset: Dataset,
+    ):
+        documents = []
+        delete_node_ids = []
+        for new_child_chunk in new_child_chunks:
+            new_child_document = Document(
+                page_content=new_child_chunk.content,
+                metadata={
+                    "doc_id": new_child_chunk.index_node_id,
+                    "doc_hash": new_child_chunk.index_node_hash,
+                    "document_id": new_child_chunk.document_id,
+                    "dataset_id": new_child_chunk.dataset_id,
+                },
+            )
+            documents.append(new_child_document)
+        for update_child_chunk in update_child_chunks:
+            child_document = Document(
+                page_content=update_child_chunk.content,
+                metadata={
+                    "doc_id": update_child_chunk.index_node_id,
+                    "doc_hash": update_child_chunk.index_node_hash,
+                    "document_id": update_child_chunk.document_id,
+                    "dataset_id": update_child_chunk.dataset_id,
+                },
+            )
+            documents.append(child_document)
+            delete_node_ids.append(update_child_chunk.index_node_id)
+        for delete_child_chunk in delete_child_chunks:
+            delete_node_ids.append(delete_child_chunk.index_node_id)
+        if dataset.indexing_technique == "high_quality":
+            # update vector index
+            vector = Vector(dataset=dataset)
+            if delete_node_ids:
+                vector.delete_by_ids(delete_node_ids)
+            if documents:
+                vector.add_texts(documents, duplicate_check=True)
+
+    @classmethod
+    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
+        vector = Vector(dataset=dataset)
+        vector.delete_by_ids([child_chunk.index_node_id])

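A call-site sketch for the reworked VectorService (hypothetical; `segment`, `document`, `dataset`, and the child-chunk lists are assumed to exist):

    from services.vector_service import VectorService

    # doc_form now selects the index processor: paragraph/QA documents keep the previous
    # vector + keyword path, while parent-child documents generate and index child chunks.
    VectorService.create_segments_vector([["keyword-a"]], [segment], dataset, document.doc_form)

    # Sync child-chunk edits in one call: new chunks are embedded and added, updated chunks are
    # re-embedded (their old index nodes deleted first), and removed chunks are dropped by node id.
    VectorService.update_child_chunk_vector(new_chunks, updated_chunks, deleted_chunks, dataset)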
+ 25 - 3
api/tasks/add_document_to_index_task.py

@@ -6,12 +6,13 @@ import click
 from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from models.dataset import DatasetAutoDisableLog, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.dataset import DocumentSegment
 
 
 @shared_task(queue="dataset")
@@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str):
                     "dataset_id": segment.dataset_id,
                 },
             )
-
+            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                child_chunks = segment.child_chunks
+                if child_chunks:
+                    child_documents = []
+                    for child_chunk in child_chunks:
+                        child_document = ChildDocument(
+                            page_content=child_chunk.content,
+                            metadata={
+                                "doc_id": child_chunk.index_node_id,
+                                "doc_hash": child_chunk.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            },
+                        )
+                        child_documents.append(child_document)
+                    document.children = child_documents
             documents.append(document)
 
         dataset = dataset_document.dataset
@@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str):
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
         index_processor.load(dataset, documents)
 
+        # delete auto disable log
+        db.session.query(DatasetAutoDisableLog).filter(
+            DatasetAutoDisableLog.document_id == dataset_document.id
+        ).delete()
+        db.session.commit()
+
         end_at = time.perf_counter()
         logging.info(
             click.style(

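The ChildDocument wiring added above recurs in several tasks in this change; a stripped-down sketch of the structure it builds (illustrative values only):

    from core.rag.models.document import ChildDocument, Document

    parent = Document(
        page_content="parent segment text",
        metadata={"doc_id": "node-1", "doc_hash": "hash-1", "document_id": "doc-1", "dataset_id": "ds-1"},
    )
    # For parent-child datasets the children carry the embeddable chunks,
    # while the parent keeps the full segment text.
    parent.children = [
        ChildDocument(
            page_content="child chunk text",
            metadata={"doc_id": "node-1-1", "doc_hash": "hash-1-1", "document_id": "doc-1", "dataset_id": "ds-1"},
        )
    ]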
+ 75 - 0
api/tasks/batch_clean_document_task.py

@@ -0,0 +1,75 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.dataset import Dataset, DocumentSegment
+from models.model import UploadFile
+
+
+@shared_task(queue="dataset")
+def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
+    """
+    Clean documents when documents are deleted.
+    :param document_ids: document ids
+    :param dataset_id: dataset id
+    :param doc_form: doc_form
+    :param file_ids: file ids
+
+    Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids)
+    """
+    logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+        if not dataset:
+            raise Exception("Document has no dataset")
+
+        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
+        # check segment is exist
+        if segments:
+            index_node_ids = [segment.index_node_id for segment in segments]
+            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+            for segment in segments:
+                image_upload_file_ids = get_image_upload_file_ids(segment.content)
+                for upload_file_id in image_upload_file_ids:
+                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                    if image_file is None:
+                        continue
+                    try:
+                        storage.delete(image_file.key)
+                    except Exception:
+                        logging.exception(
+                            "Delete image_files failed when storage deleted, image_upload_file_id: {}".format(
+                                upload_file_id
+                            )
+                        )
+                    db.session.delete(image_file)
+                db.session.delete(segment)
+
+            db.session.commit()
+        if file_ids:
+            files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
+            for file in files:
+                try:
+                    storage.delete(file.key)
+                except Exception:
+                    logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
+                db.session.delete(file)
+            db.session.commit()
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style(
+                "Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
+                fg="green",
+            )
+        )
+    except Exception:
+        logging.exception("Cleaned documents when documents deleted failed")

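A dispatch sketch for the new task (hypothetical; the deleting service is assumed to collect the ids first, and the deleted documents are assumed to share one doc_form):

    from tasks.batch_clean_document_task import batch_clean_document_task

    batch_clean_document_task.delay(document_ids, dataset.id, documents[0].doc_form, file_ids)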
+ 2 - 3
api/tasks/batch_create_segment_to_index_task.py

@@ -7,13 +7,13 @@ import click
 from celery import shared_task  # type: ignore
 from sqlalchemy import func
 
-from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
 from models.dataset import Dataset, Document, DocumentSegment
+from services.vector_service import VectorService
 
 
 @shared_task(queue="dataset")
@@ -96,8 +96,7 @@ def batch_create_segment_to_index_task(
         dataset_document.word_count += word_count_change
         db.session.add(dataset_document)
         # add index to db
-        indexing_runner = IndexingRunner()
-        indexing_runner.batch_add_segments(document_segments, dataset)
+        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
         db.session.commit()
         redis_client.setex(indexing_cache_key, 600, "completed")
         end_at = time.perf_counter()

+ 1 - 1
api/tasks/clean_dataset_task.py

@@ -62,7 +62,7 @@ def clean_dataset_task(
             if doc_form is None:
                 raise ValueError("Index type must be specified.")
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, None)
+            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
 
             for document in documents:
                 db.session.delete(document)

+ 1 - 1
api/tasks/clean_document_task.py

@@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-            index_processor.clean(dataset, index_node_ids)
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
             for segment in segments:
                 image_upload_file_ids = get_image_upload_file_ids(segment.content)

+ 1 - 1
api/tasks/clean_notion_document_task.py

@@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
-            index_processor.clean(dataset, index_node_ids)
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
             for segment in segments:
                 db.session.delete(segment)

+ 19 - 3
api/tasks/deal_dataset_vector_index_task.py

@@ -4,8 +4,9 @@ import time
 import click
 from celery import shared_task  # type: ignore
 
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 from models.dataset import Document as DatasetDocument
@@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                 db.session.commit()
 
                 # clean index
-                index_processor.clean(dataset, None, with_keywords=False)
+                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
 
                 for dataset_document in dataset_documents:
                     # update from vector index
@@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                                         "dataset_id": segment.dataset_id,
                                     },
                                 )
-
+                                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                                    child_chunks = segment.child_chunks
+                                    if child_chunks:
+                                        child_documents = []
+                                        for child_chunk in child_chunks:
+                                            child_document = ChildDocument(
+                                                page_content=child_chunk.content,
+                                                metadata={
+                                                    "doc_id": child_chunk.index_node_id,
+                                                    "doc_hash": child_chunk.index_node_hash,
+                                                    "document_id": segment.document_id,
+                                                    "dataset_id": segment.dataset_id,
+                                                },
+                                            )
+                                            child_documents.append(child_document)
+                                        document.children = child_documents
                                 documents.append(document)
                             # save vector index
                             index_processor.load(dataset, documents, with_keywords=False)

+ 6 - 16
api/tasks/delete_segment_from_index_task.py

@@ -6,48 +6,38 @@ from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
-from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document
 
 
 @shared_task(queue="dataset")
-def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str):
+def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
     """
     Async Remove segment from index
-    :param segment_id:
-    :param index_node_id:
+    :param index_node_ids:
     :param dataset_id:
     :param document_id:
 
-    Usage: delete_segment_from_index_task.delay(segment_id)
+    Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
     """
-    logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green"))
+    logging.info(click.style("Start delete segment from index", fg="green"))
     start_at = time.perf_counter()
-    indexing_cache_key = "segment_{}_delete_indexing".format(segment_id)
     try:
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
-            logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan"))
             return
 
         dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
         if not dataset_document:
-            logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan"))
             return
 
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
-            logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan"))
             return
 
         index_type = dataset_document.doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.clean(dataset, [index_node_id])
+        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
         end_at = time.perf_counter()
-        logging.info(
-            click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green")
-        )
+        logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
     except Exception:
         logging.exception("delete segment from index failed")
-    finally:
-        redis_client.delete(indexing_cache_key)

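Call-site sketch for the new batch signature (objects assumed to exist; the removed single-segment call is shown for contrast):

    from tasks.delete_segment_from_index_task import delete_segment_from_index_task

    # before: delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
    delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)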
+ 76 - 0
api/tasks/disable_segments_from_index_task.py

@@ -0,0 +1,76 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+
+@shared_task(queue="dataset")
+def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
+    """
+    Async disable segments from index
+    :param segment_ids:
+    :param dataset_id:
+    :param document_id:
+
+    Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
+    """
+    start_at = time.perf_counter()
+
+    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    if not dataset:
+        logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
+        return
+
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+
+    if not dataset_document:
+        logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
+        return
+    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+        logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
+        return
+    # sync index processor
+    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+
+    segments = (
+        db.session.query(DocumentSegment)
+        .filter(
+            DocumentSegment.id.in_(segment_ids),
+            DocumentSegment.dataset_id == dataset_id,
+            DocumentSegment.document_id == document_id,
+        )
+        .all()
+    )
+
+    if not segments:
+        return
+
+    try:
+        index_node_ids = [segment.index_node_id for segment in segments]
+        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+
+        end_at = time.perf_counter()
+        logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
+    except Exception:
+        # revert segments to enabled if the index cleanup failed
+        db.session.query(DocumentSegment).filter(
+            DocumentSegment.id.in_(segment_ids),
+            DocumentSegment.dataset_id == dataset_id,
+            DocumentSegment.document_id == document_id,
+        ).update(
+            {
+                "disabled_at": None,
+                "disabled_by": None,
+                "enabled": True,
+            }
+        )
+        db.session.commit()
+    finally:
+        for segment in segments:
+            indexing_cache_key = "segment_{}_indexing".format(segment.id)
+            redis_client.delete(indexing_cache_key)

+ 1 - 1
api/tasks/document_indexing_sync_task.py

@@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 index_node_ids = [segment.index_node_id for segment in segments]
 
                 # delete from vector index
-                index_processor.clean(dataset, index_node_ids)
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
                 for segment in segments:
                     db.session.delete(segment)

+ 1 - 1
api/tasks/document_indexing_update_task.py

@@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             index_node_ids = [segment.index_node_id for segment in segments]
 
             # delete from vector index
-            index_processor.clean(dataset, index_node_ids)
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
             for segment in segments:
                 db.session.delete(segment)

+ 3 - 3
api/tasks/duplicate_document_indexing_task.py

@@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.utcnow()
+                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(document)
         db.session.commit()
         return
@@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
                 index_node_ids = [segment.index_node_id for segment in segments]
 
                 # delete from vector index
-                index_processor.clean(dataset, index_node_ids)
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
                 for segment in segments:
                     db.session.delete(segment)
                 db.session.commit()
 
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.utcnow()
+            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             documents.append(document)
             db.session.add(document)
     db.session.commit()

+ 18 - 1
api/tasks/enable_segment_to_index_task.py

@@ -6,8 +6,9 @@ import click
 from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
+from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str):
             return
 
         index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+            child_chunks = segment.child_chunks
+            if child_chunks:
+                child_documents = []
+                for child_chunk in child_chunks:
+                    child_document = ChildDocument(
+                        page_content=child_chunk.content,
+                        metadata={
+                            "doc_id": child_chunk.index_node_id,
+                            "doc_hash": child_chunk.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        },
+                    )
+                    child_documents.append(child_document)
+                document.children = child_documents
         # save vector index
         index_processor.load(dataset, [document])
 

+ 108 - 0
api/tasks/enable_segments_to_index_task.py

@@ -0,0 +1,108 @@
+import datetime
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import ChildDocument, Document
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+
+@shared_task(queue="dataset")
+def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
+    """
+    Async enable segments to index
+    :param segment_ids:
+    :param dataset_id:
+    :param document_id:
+
+    Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
+    """
+    start_at = time.perf_counter()
+    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    if not dataset:
+        logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
+        return
+
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+
+    if not dataset_document:
+        logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
+        return
+    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+        logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
+        return
+    # sync index processor
+    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+
+    segments = (
+        db.session.query(DocumentSegment)
+        .filter(
+            DocumentSegment.id.in_(segment_ids),
+            DocumentSegment.dataset_id == dataset_id,
+            DocumentSegment.document_id == document_id,
+        )
+        .all()
+    )
+    if not segments:
+        return
+
+    try:
+        documents = []
+        for segment in segments:
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": document_id,
+                    "dataset_id": dataset_id,
+                },
+            )
+
+            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                child_chunks = segment.child_chunks
+                if child_chunks:
+                    child_documents = []
+                    for child_chunk in child_chunks:
+                        child_document = ChildDocument(
+                            page_content=child_chunk.content,
+                            metadata={
+                                "doc_id": child_chunk.index_node_id,
+                                "doc_hash": child_chunk.index_node_hash,
+                                "document_id": document_id,
+                                "dataset_id": dataset_id,
+                            },
+                        )
+                        child_documents.append(child_document)
+                    document.children = child_documents
+            documents.append(document)
+        # save vector index
+        index_processor.load(dataset, documents)
+
+        end_at = time.perf_counter()
+        logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
+    except Exception as e:
+        logging.exception("enable segments to index failed")
+        # update segment error msg
+        db.session.query(DocumentSegment).filter(
+            DocumentSegment.id.in_(segment_ids),
+            DocumentSegment.dataset_id == dataset_id,
+            DocumentSegment.document_id == document_id,
+        ).update(
+            {
+                "error": str(e),
+                "status": "error",
+                "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                "enabled": False,
+            }
+        )
+        db.session.commit()
+    finally:
+        for segment in segments:
+            indexing_cache_key = "segment_{}_indexing".format(segment.id)
+            redis_client.delete(indexing_cache_key)

+ 1 - 1
api/tasks/remove_document_from_index_task.py

@@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
             try:
-                index_processor.clean(dataset, index_node_ids)
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
             except Exception:
                 logging.exception(f"clean dataset {dataset.id} from index failed")
 

+ 7 - 7
api/tasks/retry_document_indexing_task.py

@@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.utcnow()
+                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                 db.session.add(document)
                 db.session.commit()
             redis_client.delete(retry_indexing_cache_key)
@@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
             if segments:
                 index_node_ids = [segment.index_node_id for segment in segments]
                 # delete from vector index
-                index_processor.clean(dataset, index_node_ids)
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
-                for segment in segments:
-                    db.session.delete(segment)
-                db.session.commit()
+            for segment in segments:
+                db.session.delete(segment)
+            db.session.commit()
 
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.utcnow()
+            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
 
@@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
         except Exception as ex:
             document.indexing_status = "error"
             document.error = str(ex)
-            document.stopped_at = datetime.datetime.utcnow()
+            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
             logging.info(click.style(str(ex), fg="yellow"))
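The timestamp changes in this hunk replace datetime.datetime.utcnow() (naive, and deprecated in recent Python releases) with an aware UTC "now" that is immediately stripped back to naive, so the stored values still match the existing naive-UTC columns. A tiny helper capturing the inlined expression could look like this (the helper name is hypothetical; the tasks keep the expression inline):

import datetime


def naive_utc_now() -> datetime.datetime:
    # Aware "now" in UTC, then drop tzinfo so it compares cleanly with the
    # naive UTC datetimes already stored in the existing columns.
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


# e.g. document.stopped_at = naive_utc_now()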

+ 7 - 7
api/tasks/sync_website_document_indexing_task.py

@@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
         if document:
             document.indexing_status = "error"
             document.error = str(e)
-            document.stopped_at = datetime.datetime.utcnow()
+            document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
             db.session.add(document)
             db.session.commit()
         redis_client.delete(sync_indexing_cache_key)
@@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]
             # delete from vector index
-            index_processor.clean(dataset, index_node_ids)
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
-            for segment in segments:
-                db.session.delete(segment)
-            db.session.commit()
+        for segment in segments:
+            db.session.delete(segment)
+        db.session.commit()
 
         document.indexing_status = "parsing"
-        document.processing_started_at = datetime.datetime.utcnow()
+        document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.add(document)
         db.session.commit()
 
@@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     except Exception as ex:
         document.indexing_status = "error"
         document.error = str(ex)
-        document.stopped_at = datetime.datetime.utcnow()
+        document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         db.session.add(document)
         db.session.commit()
         logging.info(click.style(str(ex), fg="yellow"))
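retry_document_indexing_task and sync_website_document_indexing_task now share the same shape: collect the segments' index_node_ids, clean the vector/keyword index (including child chunks) while those ids are still readable, delete the segment rows, then reset the document to "parsing" for a fresh indexing run. A hypothetical outline of that order of operations (the parameter names are stand-ins, not the real service APIs):

import datetime


def reindex_document_sketch(db_session, index_processor, dataset, document, segments) -> None:
    # 1) Clean derived indexes first, while the node ids can still be read
    #    from the segment rows that are about to be deleted.
    index_node_ids = [segment.index_node_id for segment in segments]
    if index_node_ids:
        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

    # 2) Remove the segment rows themselves.
    for segment in segments:
        db_session.delete(segment)
    db_session.commit()

    # 3) Hand the document back to the indexing pipeline.
    document.indexing_status = "parsing"
    document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    db_session.add(document)
    db_session.commit()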

+ 98 - 0
api/templates/clean_document_job_mail_template-US.html

@@ -0,0 +1,98 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <title>Documents Disabled Notification</title>
+  <style>
+    body {
+      font-family: Arial, sans-serif;
+      margin: 0;
+      padding: 0;
+      background-color: #f5f5f5;
+    }
+    .email-container {
+      max-width: 600px;
+      margin: 20px auto;
+      background: #ffffff;
+      border-radius: 10px;
+      box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
+      overflow: hidden;
+    }
+    .header {
+      background-color: #eef2fa;
+      padding: 20px;
+      text-align: center;
+    }
+    .header img {
+      height: 40px;
+    }
+    .content {
+      padding: 20px;
+      line-height: 1.6;
+      color: #333;
+    }
+    .content h1 {
+      font-size: 24px;
+      color: #222;
+    }
+    .content p {
+      margin: 10px 0;
+    }
+    .content ul {
+      padding-left: 20px;
+    }
+    .content ul li {
+      margin-bottom: 10px;
+    }
+    .cta-button {
+      display: block;
+      margin: 20px auto;
+      padding: 10px 20px;
+      background-color: #4e89f9;
+      color: #ffffff;
+      text-align: center;
+      text-decoration: none;
+      border-radius: 5px;
+      width: fit-content;
+    }
+    .footer {
+      text-align: center;
+      padding: 10px;
+      font-size: 12px;
+      color: #777;
+      background-color: #f9f9f9;
+    }
+  </style>
+</head>
+<body>
+  <div class="email-container">
+    <!-- Header -->
+    <div class="header">
+      <img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo">
+    </div>
+
+    <!-- Content -->
+    <div class="content">
+      <h1>Some Documents in Your Knowledge Base Have Been Disabled</h1>
+      <p>Dear {{userName}},</p>
+      <p>
+        We're sorry for the inconvenience. To ensure optimal performance, documents
+        that haven't been updated or accessed in the past 7 days have been disabled in
+        the following knowledge bases:
+      </p>
+      <ul>
+        {{knowledge_details}}
+      </ul>
+      <p>You can re-enable them anytime.</p>
+      <a href="{{url}}" class="cta-button">Re-enable in Dify</a>
+    </div>
+
+    <!-- Footer -->
+    <div class="footer">
+      Sincerely,<br>
+      The Dify Team
+    </div>
+  </div>
+</body>
+</html>
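The template exposes three placeholders: {{userName}}, {{knowledge_details}} (rendered inside a <ul>, so it is expected to arrive as pre-built <li> markup) and {{url}}. Below is a minimal sketch of filling it in with Jinja2; the loader path and function name are assumptions, and the real sender lives in api/schedule/mail_clean_document_notify_task.py, which is not shown here.

from jinja2 import Environment, FileSystemLoader, select_autoescape
from markupsafe import Markup


def render_clean_document_mail(user_name: str, knowledge_base_names: list[str], url: str) -> str:
    env = Environment(
        loader=FileSystemLoader("api/templates"),  # assumed template directory
        autoescape=select_autoescape(["html"]),
    )
    template = env.get_template("clean_document_job_mail_template-US.html")
    # {{knowledge_details}} sits inside a <ul>, so pass ready-made <li> items and
    # mark the joined string safe so autoescaping keeps the tags intact while the
    # knowledge-base names themselves are still escaped.
    details = Markup("").join(
        Markup("<li>{}</li>").format(name) for name in knowledge_base_names
    )
    return template.render(userName=user_name, knowledge_details=details, url=url)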