
Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
6c4e6bf1d6
100 changed files with 3097 additions and 5550 deletions
  1. api/celerybeat-schedule.db (BIN)
  2. api/config.py (+2 -1)
  3. api/controllers/console/datasets/data_source.py (+26 -6)
  4. api/controllers/console/datasets/datasets.py (+41 -33)
  5. api/controllers/console/datasets/datasets_document.py (+43 -29)
  6. api/core/data_loader/file_extractor.py (+0 -107)
  7. api/core/data_loader/loader/html.py (+0 -34)
  8. api/core/data_loader/loader/pdf.py (+0 -55)
  9. api/core/docstore/dataset_docstore.py (+1 -1)
  10. api/core/features/annotation_reply.py (+7 -31)
  11. api/core/index/index.py (+0 -51)
  12. api/core/index/vector_index/base.py (+0 -305)
  13. api/core/index/vector_index/milvus_vector_index.py (+0 -165)
  14. api/core/index/vector_index/qdrant_vector_index.py (+0 -229)
  15. api/core/index/vector_index/vector_index.py (+0 -90)
  16. api/core/index/vector_index/weaviate_vector_index.py (+0 -179)
  17. api/core/indexing_runner.py (+134 -252)
  18. api/core/rag/__init__.py (+0 -0)
  19. api/core/rag/cleaner/clean_processor.py (+38 -0)
  20. api/core/rag/cleaner/cleaner_base.py (+12 -0)
  21. api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py (+12 -0)
  22. api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py (+15 -0)
  23. api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py (+12 -0)
  24. api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py (+11 -0)
  25. api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py (+11 -0)
  26. api/core/rag/data_post_processor/__init__.py (+0 -0)
  27. api/core/rag/data_post_processor/data_post_processor.py (+49 -0)
  28. api/core/rag/data_post_processor/reorder.py (+19 -0)
  29. api/core/rag/datasource/__init__.py (+0 -0)
  30. api/core/rag/datasource/entity/embedding.py (+21 -0)
  31. api/core/rag/datasource/keyword/__init__.py (+0 -0)
  32. api/core/rag/datasource/keyword/jieba/__init__.py (+0 -0)
  33. api/core/rag/datasource/keyword/jieba/jieba.py (+18 -90)
  34. api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py (+1 -1)
  35. api/core/rag/datasource/keyword/jieba/stopwords.py (+0 -0)
  36. api/core/rag/datasource/keyword/keyword_base.py (+5 -23)
  37. api/core/rag/datasource/keyword/keyword_factory.py (+60 -0)
  38. api/core/rag/datasource/retrieval_service.py (+165 -0)
  39. api/core/rag/datasource/vdb/__init__.py (+0 -0)
  40. api/core/rag/datasource/vdb/field.py (+10 -0)
  41. api/core/rag/datasource/vdb/milvus/__init__.py (+0 -0)
  42. api/core/rag/datasource/vdb/milvus/milvus_vector.py (+214 -0)
  43. api/core/rag/datasource/vdb/qdrant/__init__.py (+0 -0)
  44. api/core/rag/datasource/vdb/qdrant/qdrant_vector.py (+360 -0)
  45. api/core/rag/datasource/vdb/vector_base.py (+62 -0)
  46. api/core/rag/datasource/vdb/vector_factory.py (+171 -0)
  47. api/core/rag/datasource/vdb/weaviate/__init__.py (+0 -0)
  48. api/core/rag/datasource/vdb/weaviate/weaviate_vector.py (+235 -0)
  49. api/core/rag/extractor/blod/blod.py (+166 -0)
  50. api/core/rag/extractor/csv_extractor.py (+71 -0)
  51. api/core/rag/extractor/entity/datasource_type.py (+6 -0)
  52. api/core/rag/extractor/entity/extract_setting.py (+36 -0)
  53. api/core/rag/extractor/excel_extractor.py (+14 -9)
  54. api/core/rag/extractor/extract_processor.py (+139 -0)
  55. api/core/rag/extractor/extractor_base.py (+12 -0)
  56. api/core/rag/extractor/helpers.py (+46 -0)
  57. api/core/rag/extractor/html_extractor.py (+24 -20)
  58. api/core/rag/extractor/markdown_extractor.py (+14 -26)
  59. api/core/rag/extractor/notion_extractor.py (+23 -37)
  60. api/core/rag/extractor/pdf_extractor.py (+72 -0)
  61. api/core/rag/extractor/text_extractor.py (+50 -0)
  62. api/core/rag/extractor/unstructured/unstructured_doc_extractor.py (+61 -0)
  63. api/core/rag/extractor/unstructured/unstructured_eml_extractor.py (+5 -4)
  64. api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py (+4 -4)
  65. api/core/rag/extractor/unstructured/unstructured_msg_extractor.py (+4 -4)
  66. api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py (+8 -7)
  67. api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py (+9 -7)
  68. api/core/rag/extractor/unstructured/unstructured_text_extractor.py (+4 -4)
  69. api/core/rag/extractor/unstructured/unstructured_xml_extractor.py (+4 -4)
  70. api/core/rag/extractor/word_extractor.py (+62 -0)
  71. api/core/rag/index_processor/__init__.py (+0 -0)
  72. api/core/rag/index_processor/constant/__init__.py (+0 -0)
  73. api/core/rag/index_processor/constant/index_type.py (+8 -0)
  74. api/core/rag/index_processor/index_processor_base.py (+70 -0)
  75. api/core/rag/index_processor/index_processor_factory.py (+28 -0)
  76. api/core/rag/index_processor/processor/__init__.py (+0 -0)
  77. api/core/rag/index_processor/processor/paragraph_index_processor.py (+92 -0)
  78. api/core/rag/index_processor/processor/qa_index_processor.py (+161 -0)
  79. api/core/rag/models/__init__.py (+0 -0)
  80. api/core/rag/models/document.py (+16 -0)
  81. api/core/tool/web_reader_tool.py (+5 -5)
  82. api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py (+16 -71)
  83. api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py (+17 -95)
  84. api/core/tools/utils/web_reader_tool.py (+5 -5)
  85. api/core/vector_store/milvus_vector_store.py (+0 -56)
  86. api/core/vector_store/qdrant_vector_store.py (+0 -76)
  87. api/core/vector_store/vector/milvus.py (+0 -852)
  88. api/core/vector_store/vector/qdrant.py (+0 -1759)
  89. api/core/vector_store/vector/weaviate.py (+0 -506)
  90. api/core/vector_store/weaviate_vector_store.py (+0 -38)
  91. api/events/event_handlers/clean_when_dataset_deleted.py (+1 -1)
  92. api/events/event_handlers/clean_when_document_deleted.py (+2 -1)
  93. api/models/dataset.py (+8 -0)
  94. api/schedule/clean_unused_datasets_task.py (+4 -13)
  95. api/services/dataset_service.py (+5 -4)
  96. api/services/file_service.py (+5 -4)
  97. api/services/hit_testing_service.py (+13 -62)
  98. api/services/retrieval_service.py (+0 -119)
  99. api/services/vector_service.py (+31 -54)
  100. api/tasks/add_document_to_index_task.py (+5 -11)

BIN
api/celerybeat-schedule.db


+ 2 - 1
api/config.py

@@ -56,6 +56,7 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
     'BATCH_UPLOAD_LIMIT': 20
 }
 
@@ -183,7 +184,7 @@ class Config:
         # Currently, only support: qdrant, milvus, zilliz, weaviate
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
-
+        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')

+ 26 - 6
api/controllers/console/datasets/data_source.py

@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
@@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource):
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
 
-        loader = NotionLoader(
-            notion_access_token=data_source_binding.access_token,
+        extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
-            notion_page_type=page_type
+            notion_page_type=page_type,
+            notion_access_token=data_source_binding.access_token
         )
 
-        text_docs = loader.load()
+        text_docs = extractor.extract()
         return {
             'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
@@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        notion_info_list = args['notion_info_list']
+        extract_settings = []
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            for page in notion_info['pages']:
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page['page_id'],
+                        "notion_page_type": page['type']
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                     args['process_rule'], args['doc_form'],
+                                                     args['doc_language'])
         return response, 200

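For reference, the request body this handler now accepts can be inferred from the reqparse arguments above; a sketch as a Python dict, with every value being a placeholder:

# Illustrative request body for the Notion indexing-estimate endpoint,
# inferred from the parser arguments in the diff above. Values are placeholders.
payload = {
    "notion_info_list": [
        {
            "workspace_id": "<workspace-id>",
            "pages": [
                {"page_id": "<page-id>", "type": "page"}
            ]
        }
    ],
    "process_rule": {"mode": "automatic"},  # structure checked by DocumentService.estimate_args_validate
    "doc_form": "text_model",               # new argument, defaults to 'text_model'
    "doc_language": "English"               # new argument, defaults to 'English'
}
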
+ 41 - 33
api/controllers/console/datasets/datasets.py

@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
-                    choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                    nullable=True,
-                    help='Invalid indexing technique.')
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                            nullable=True,
+                            help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True, 
+        parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")
 
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type']
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         else:
             raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+
         return response, 200
@@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-

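Both branches of this estimate endpoint now reduce to the same pattern: build a list of ExtractSetting objects and hand them to a single IndexingRunner.indexing_estimate call. A condensed sketch of the upload-file path; only ExtractSetting and indexing_estimate come from this commit, the wrapper function itself is illustrative:

# Condensed sketch of the unified estimate flow shown in the diff above.
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting


def estimate_upload_files(tenant_id, file_details, process_rule,
                          doc_form='text_model', doc_language='English',
                          dataset_id=None, indexing_technique='high_quality'):
    # one ExtractSetting per uploaded file, mirroring the controller code above
    extract_settings = [
        ExtractSetting(
            datasource_type="upload_file",
            upload_file=file_detail,
            document_model=doc_form
        )
        for file_detail in file_details
    ]
    return IndexingRunner().indexing_estimate(
        tenant_id, extract_settings, process_rule,
        doc_form, doc_language, dataset_id, indexing_technique
    )
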
+ 43 - 29
api/controllers/console/datasets/datasets_document.py

@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args
 
         document_id = req_data.get('document_id')
-        
+
         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 if not file:
                     raise NotFound('File not found.')
 
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file,
+                    document_model=document.doc_form
+                )
+
                 indexing_runner = IndexingRunner()
 
                 try:
-                    response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                      data_process_rule_dict, None,
-                                                                      'English', dataset_id)
+                    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                                 data_process_rule_dict, document.doc_form,
+                                                                 'English', dataset_id)
                 except LLMBadRequestError:
                     raise ProviderNotInitializeError(
                         "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 }
                 info_list.append(notion_info)
 
-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
 
-            if file_details is None:
-                raise NotFound("File not found.")
+                if file_detail is None:
+                    raise NotFound("File not found.")
 
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type']
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
 
+            else:
+                raise ValueError('Data source type not support')
             indexing_runner = IndexingRunner()
             try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    info_list,
-                                                                    data_process_rule_dict,
-                                                                    None, 'English', dataset_id)
+                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                             data_process_rule_dict, document.doc_form,
+                                                             'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
                     "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')
         return response


+ 0 - 107
api/core/data_loader/file_extractor.py

@@ -1,107 +0,0 @@
-import tempfile
-from pathlib import Path
-from typing import Optional, Union
-
-import requests
-from flask import current_app
-from langchain.document_loaders import Docx2txtLoader, TextLoader
-from langchain.schema import Document
-
-from core.data_loader.loader.csv_loader import CSVLoader
-from core.data_loader.loader.excel import ExcelLoader
-from core.data_loader.loader.html import HTMLLoader
-from core.data_loader.loader.markdown import MarkdownLoader
-from core.data_loader.loader.pdf import PdfLoader
-from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
-from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
-from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
-from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
-from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
-from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
-from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
-USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-
-
-class FileExtractor:
-    @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, file_path)
-
-            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
-
-    @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
-        response = requests.get(url, headers={
-            "User-Agent": USER_AGENT
-        })
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(url).suffix
-            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, 'wb') as file:
-                file.write(response.content)
-
-            return cls.load_from_file(file_path, return_text)
-
-    @classmethod
-    def load_from_file(cls, file_path: str, return_text: bool = False,
-                       upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[list[Document], str]:
-        input_file = Path(file_path)
-        delimiter = '\n'
-        file_extension = input_file.suffix.lower()
-        etl_type = current_app.config['ETL_TYPE']
-        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
-        if etl_type == 'Unstructured':
-            if file_extension == '.xlsx':
-                loader = ExcelLoader(file_path)
-            elif file_extension == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif file_extension in ['.md', '.markdown']:
-                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
-                    else MarkdownLoader(file_path, autodetect_encoding=True)
-            elif file_extension in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif file_extension in ['.docx']:
-                loader = Docx2txtLoader(file_path)
-            elif file_extension == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            elif file_extension == '.msg':
-                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
-            elif file_extension == '.eml':
-                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
-            elif file_extension == '.ppt':
-                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
-            elif file_extension == '.pptx':
-                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
-            elif file_extension == '.xml':
-                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
-            else:
-                # txt
-                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
-                    else TextLoader(file_path, autodetect_encoding=True)
-        else:
-            if file_extension == '.xlsx':
-                loader = ExcelLoader(file_path)
-            elif file_extension == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif file_extension in ['.md', '.markdown']:
-                loader = MarkdownLoader(file_path, autodetect_encoding=True)
-            elif file_extension in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif file_extension in ['.docx']:
-                loader = Docx2txtLoader(file_path)
-            elif file_extension == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            else:
-                # txt
-                loader = TextLoader(file_path, autodetect_encoding=True)
-
-        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

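The 107-line FileExtractor, which dispatched on file extension inside one class method, is removed in favour of one extractor class per format under api/core/rag/extractor/, coordinated by the new extract_processor.py. Judging from the 12-line extractor_base.py and the extractor.extract() calls earlier in this diff, the shared interface is roughly the following sketch, not the verbatim file:

# Hedged sketch of the extractor interface that replaces FileExtractor's
# if/elif chain. Inferred from extractor_base.py and the extract() call sites
# in this commit; the exact base-class definition may differ.
from abc import ABC, abstractmethod

from core.rag.models.document import Document


class BaseExtractor(ABC):
    """Interface that each per-format extractor implements."""

    @abstractmethod
    def extract(self) -> list[Document]:
        raise NotImplementedError
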
+ 0 - 34
api/core/data_loader/loader/html.py

@@ -1,34 +0,0 @@
-import logging
-
-from bs4 import BeautifulSoup
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
-
-logger = logging.getLogger(__name__)
-
-
-class HTMLLoader(BaseLoader):
-    """Load html files.
-
-
-    Args:
-        file_path: Path to the file to load.
-    """
-
-    def __init__(
-        self,
-        file_path: str
-    ):
-        """Initialize with file path."""
-        self._file_path = file_path
-
-    def load(self) -> list[Document]:
-        return [Document(page_content=self._load_as_text())]
-
-    def _load_as_text(self) -> str:
-        with open(self._file_path, "rb") as fp:
-            soup = BeautifulSoup(fp, 'html.parser')
-            text = soup.get_text()
-            text = text.strip() if text else ''
-
-        return text

+ 0 - 55
api/core/data_loader/loader/pdf.py

@@ -1,55 +0,0 @@
-import logging
-from typing import Optional
-
-from langchain.document_loaders import PyPDFium2Loader
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
-
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-logger = logging.getLogger(__name__)
-
-
-class PdfLoader(BaseLoader):
-    """Load pdf files.
-
-
-    Args:
-        file_path: Path to the file to load.
-    """
-
-    def __init__(
-        self,
-        file_path: str,
-        upload_file: Optional[UploadFile] = None
-    ):
-        """Initialize with file path."""
-        self._file_path = file_path
-        self._upload_file = upload_file
-
-    def load(self) -> list[Document]:
-        plaintext_file_key = ''
-        plaintext_file_exists = False
-        if self._upload_file:
-            if self._upload_file.hash:
-                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
-                                     + self._upload_file.hash + '.0625.plaintext'
-                try:
-                    text = storage.load(plaintext_file_key).decode('utf-8')
-                    plaintext_file_exists = True
-                    return [Document(page_content=text)]
-                except FileNotFoundError:
-                    pass
-        documents = PyPDFium2Loader(file_path=self._file_path).load()
-        text_list = []
-        for document in documents:
-            text_list.append(document.page_content)
-        text = "\n\n".join(text_list)
-
-        # save plaintext file for caching
-        if not plaintext_file_exists and plaintext_file_key:
-            storage.save(plaintext_file_key, text.encode('utf-8'))
-
-        return documents
-

+ 1 - 1
api/core/docstore/dataset_docstore.py

@@ -1,12 +1,12 @@
 from collections.abc import Sequence
 from typing import Any, Optional, cast
 
-from langchain.schema import Document
 from sqlalchemy import func
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 

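The langchain Document class is replaced across the codebase by Dify's own model in api/core/rag/models/document.py (16 lines added in this commit). Based only on how it is used in this diff, Document(page_content=..., metadata={...}) and doc.page_content, a plausible minimal shape is the following; the real module may define more:

# Plausible minimal shape of core.rag.models.document.Document, inferred from
# its call sites in this commit; an assumption, not the verbatim file.
from typing import Optional

from pydantic import BaseModel


class Document(BaseModel):
    """A retrieval unit: a chunk of text plus arbitrary metadata
    (doc_id, doc_hash, document_id, dataset_id, score, ...)."""

    page_content: str
    metadata: Optional[dict] = None
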
+ 7 - 31
api/core/features/annotation_reply.py

@@ -1,13 +1,8 @@
 import logging
 from typing import Optional
 
-from flask import current_app
-
-from core.embedding.cached_embedding import CacheEmbedding
 from core.entities.application_entities import InvokeFrom
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
             embedding_provider_name = collection_binding_detail.provider_name
             embedding_model_name = collection_binding_detail.model_name
 
-            model_manager = ModelManager()
-            model_instance = model_manager.get_model_instance(
-                tenant_id=app_record.tenant_id,
-                provider=embedding_provider_name,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=embedding_model_name
-            )
-
-            # get embedding model
-            embeddings = CacheEmbedding(model_instance)
-
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                 embedding_provider_name,
                 embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                 collection_binding_id=dataset_collection_binding.id
             )
 
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings,
-                attributes=['doc_id', 'annotation_id', 'app_id']
-            )
+            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
 
-            documents = vector_index.search(
+            documents = vector.search_by_vector(
                 query=query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': 1,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
+                k=1,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
                 }
             )
 

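The old VectorIndex required the caller to wire up the embedding model and the Flask config; the new Vector factory resolves both from the dataset, so callers pass only the dataset and the query parameters. A minimal usage sketch mirroring the call above; the score handling is an assumption about what search_by_vector attaches to the returned documents:

# Minimal retrieval sketch using the new Vector factory, based on the call in
# annotation_reply.py above. Assumption: returned documents carry a 'score'
# entry in metadata, as the removed VectorIndex code did.
from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset


def find_matching_annotation(dataset: Dataset, query: str, score_threshold: float):
    vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
    documents = vector.search_by_vector(
        query=query,
        k=1,
        score_threshold=score_threshold,
        filter={'group_id': [dataset.id]}
    )
    if not documents:
        return None
    top = documents[0]
    # annotation_id is available in metadata because it was listed in `attributes`
    return top.metadata.get('annotation_id'), top.metadata.get('score')
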
+ 0 - 51
api/core/index/index.py

@@ -1,51 +0,0 @@
-from flask import current_app
-from langchain.embeddings import OpenAIEmbeddings
-
-from core.embedding.cached_embedding import CacheEmbedding
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from models.dataset import Dataset
-
-
-class IndexBuilder:
-    @classmethod
-    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
-        if indexing_technique == "high_quality":
-            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
-                return None
-
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                model_type=ModelType.TEXT_EMBEDDING,
-                provider=dataset.embedding_model_provider,
-                model=dataset.embedding_model
-            )
-
-            embeddings = CacheEmbedding(embedding_model)
-
-            return VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-        elif indexing_technique == "economy":
-            return KeywordTableIndex(
-                dataset=dataset,
-                config=KeywordTableConfig(
-                    max_keywords_per_chunk=10
-                )
-            )
-        else:
-            raise ValueError('Unknown indexing technique')
-
-    @classmethod
-    def get_default_high_quality_index(cls, dataset: Dataset):
-        embeddings = OpenAIEmbeddings(openai_api_key=' ')
-        return VectorIndex(
-            dataset=dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )

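The removed IndexBuilder chose between a VectorIndex and a KeywordTableIndex based on indexing_technique. Under the new layout that split is handled by the factories added in this commit (vector_factory.py and keyword_factory.py); a rough, partly hypothetical sketch of the equivalent decision:

# Rough equivalent of the removed IndexBuilder under the new rag package layout.
# Vector comes from this commit's vector_factory; the Keyword class name and its
# constructor are assumptions based on the added keyword_factory.py.
from core.rag.datasource.vdb.vector_factory import Vector
from models.dataset import Dataset


def get_index(dataset: Dataset, indexing_technique: str):
    if indexing_technique == 'high_quality':
        # embeddings are resolved internally from the dataset's embedding model binding
        return Vector(dataset)
    elif indexing_technique == 'economy':
        from core.rag.datasource.keyword.keyword_factory import Keyword  # assumed class name
        return Keyword(dataset)
    raise ValueError('Unknown indexing technique')
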
+ 0 - 305
api/core/index/vector_index/base.py

@@ -1,305 +0,0 @@
-import json
-import logging
-from abc import abstractmethod
-from typing import Any, cast
-
-from langchain.embeddings.base import Embeddings
-from langchain.schema import BaseRetriever, Document
-from langchain.vectorstores import VectorStore
-
-from core.index.base import BaseIndex
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
-from models.dataset import Document as DatasetDocument
-
-
-class BaseVectorIndex(BaseIndex):
-
-    def __init__(self, dataset: Dataset, embeddings: Embeddings):
-        super().__init__(dataset)
-        self._embeddings = embeddings
-        self._vector_store = None
-
-    def get_type(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_index_name(self, dataset: Dataset) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def to_index_struct(self) -> dict:
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_vector_store(self) -> VectorStore:
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_vector_store_class(self) -> type:
-        raise NotImplementedError
-
-    @abstractmethod
-    def search_by_full_text_index(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
-        raise NotImplementedError
-
-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
-        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-
-        if search_type == 'similarity_score_threshold':
-            score_threshold = search_kwargs.get("score_threshold")
-            if (score_threshold is None) or (not isinstance(score_threshold, float)):
-                search_kwargs['score_threshold'] = .0
-
-            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
-                query, **search_kwargs
-            )
-
-            docs = []
-            for doc, similarity in docs_with_similarity:
-                doc.metadata['score'] = similarity
-                docs.append(doc)
-
-            return docs
-
-        # similarity k
-        # mmr k, fetch_k, lambda_mult
-        # similarity_score_threshold k
-        return vector_store.as_retriever(
-            search_type=search_type,
-            search_kwargs=search_kwargs
-        ).get_relevant_documents(query)
-
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        return vector_store.as_retriever(**kwargs)
-
-    def add_texts(self, texts: list[Document], **kwargs):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        if kwargs.get('duplicate_check', False):
-            texts = self._filter_duplicate_texts(texts)
-
-        uuids = self._get_uuids(texts)
-        vector_store.add_documents(texts, uuids=uuids)
-
-    def text_exists(self, id: str) -> bool:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        return vector_store.text_exists(id)
-
-    def delete_by_ids(self, ids: list[str]) -> None:
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        for node_id in ids:
-            vector_store.del_text(node_id)
-
-    def delete_by_group_id(self, group_id: str) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        if self.dataset.collection_binding_id:
-            vector_store.delete_by_group_id(group_id)
-        else:
-            vector_store.delete()
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def _is_origin(self):
-        return False
-
-    def recreate_dataset(self, dataset: Dataset):
-        logging.info(f"Recreating dataset {dataset.id}")
-
-        try:
-            self.delete()
-        except Exception as e:
-            raise e
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        origin_index_struct = self.dataset.index_struct[:]
-        self.dataset.index_struct = None
-
-        if documents:
-            try:
-                self.create(documents)
-            except Exception as e:
-                self.dataset.index_struct = origin_index_struct
-                raise e
-
-            dataset.index_struct = json.dumps(self.to_index_struct())
-
-        db.session.commit()
-
-        self.dataset = dataset
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def create_qdrant_dataset(self, dataset: Dataset):
-        logging.info(f"create_qdrant_dataset {dataset.id}")
-
-        try:
-            self.delete()
-        except Exception as e:
-            raise e
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        if documents:
-            try:
-                self.create(documents)
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def update_qdrant_dataset(self, dataset: Dataset):
-        logging.info(f"update_qdrant_dataset {dataset.id}")
-
-        segment = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == dataset.id,
-            DocumentSegment.status == 'completed',
-            DocumentSegment.enabled == True
-        ).first()
-
-        if segment:
-            try:
-                exist = self.text_exists(segment.index_node_id)
-                if exist:
-                    index_struct = {
-                        "type": 'qdrant',
-                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                    }
-                    dataset.index_struct = json.dumps(index_struct)
-                    db.session.commit()
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
-        logging.info(f"restore dataset in_one,_dataset {dataset.id}")
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        if documents:
-            try:
-                self.add_texts(documents)
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
-        logging.info(f"delete original collection: {dataset.id}")
-
-        self.delete()
-
-        dataset.collection_binding_id = dataset_collection_binding.id
-        db.session.add(dataset)
-        db.session.commit()
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 0 - 165
api/core/index/vector_index/milvus_vector_index.py

@@ -1,165 +0,0 @@
-from typing import Any, cast
-
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel, root_validator
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.milvus_vector_store import MilvusVectorStore
-from models.dataset import Dataset
-
-
-class MilvusConfig(BaseModel):
-    host: str
-    port: int
-    user: str
-    password: str
-    secure: bool = False
-    batch_size: int = 100
-
-    @root_validator()
-    def validate_config(cls, values: dict) -> dict:
-        if not values['host']:
-            raise ValueError("config MILVUS_HOST is required")
-        if not values['port']:
-            raise ValueError("config MILVUS_PORT is required")
-        if not values['user']:
-            raise ValueError("config MILVUS_USER is required")
-        if not values['password']:
-            raise ValueError("config MILVUS_PASSWORD is required")
-        return values
-
-    def to_milvus_params(self):
-        return {
-            'host': self.host,
-            'port': self.port,
-            'user': self.user,
-            'password': self.password,
-            'secure': self.secure
-        }
-
-
-class MilvusVectorIndex(BaseVectorIndex):
-    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
-        super().__init__(dataset, embeddings)
-        self._client_config = config
-
-    def get_type(self) -> str:
-        return 'milvus'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-
-            return class_prefix
-
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        index_params = {
-            'metric_type': 'IP',
-            'index_type': "HNSW",
-            'params': {"M": 8, "efConstruction": 64}
-        }
-        self._vector_store = MilvusVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=self.get_index_name(self.dataset),
-            connection_args=self._client_config.to_milvus_params(),
-            index_params=index_params
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = MilvusVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=collection_name,
-            ids=uuids,
-            content_payload_key='page_content'
-        )
-
-        return self
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-
-        return MilvusVectorStore(
-            collection_name=self.get_index_name(self.dataset),
-            embedding_function=self._embeddings,
-            connection_args=self._client_config.to_milvus_params()
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return MilvusVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_document_id(document_id)
-        if ids:
-            vector_store.del_texts({
-                'filter': f'id in {ids}'
-            })
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_metadata_field(key, value)
-        if ids:
-            vector_store.del_texts({
-                'filter': f'id in {ids}'
-            })
-
-    def delete_by_ids(self, doc_ids: list[str]) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_doc_ids(doc_ids)
-        vector_store.del_texts({
-            'filter': f' id in {ids}'
-        })
-
-    def delete_by_group_id(self, group_id: str) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-            ],
-        ))
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        # milvus/zilliz doesn't support bm25 search
-        return []

+ 0 - 229
api/core/index/vector_index/qdrant_vector_index.py

@@ -1,229 +0,0 @@
-import os
-from typing import Any, Optional, cast
-
-import qdrant_client
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel
-from qdrant_client.http.models import HnswConfigDiff
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
-
-
-class QdrantConfig(BaseModel):
-    endpoint: str
-    api_key: Optional[str]
-    timeout: float = 20
-    root_path: Optional[str]
-
-    def to_qdrant_params(self):
-        if self.endpoint and self.endpoint.startswith('path:'):
-            path = self.endpoint.replace('path:', '')
-            if not os.path.isabs(path):
-                path = os.path.join(self.root_path, path)
-
-            return {
-                'path': path
-            }
-        else:
-            return {
-                'url': self.endpoint,
-                'api_key': self.api_key,
-                'timeout': self.timeout
-            }
-
-
-class QdrantVectorIndex(BaseVectorIndex):
-    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
-        super().__init__(dataset, embeddings)
-        self._client_config = config
-
-    def get_type(self) -> str:
-        return 'qdrant'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if dataset.collection_binding_id:
-            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
-                one_or_none()
-            if dataset_collection_binding:
-                return dataset_collection_binding.collection_name
-            else:
-                raise ValueError('Dataset Collection Bindings is not exist!')
-        else:
-            if self.dataset.index_struct_dict:
-                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-                return class_prefix
-
-            dataset_id = dataset.id
-            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = QdrantVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=self.get_index_name(self.dataset),
-            ids=uuids,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id',
-            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                       max_indexing_threads=0, on_disk=False),
-            **self._client_config.to_qdrant_params()
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = QdrantVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=collection_name,
-            ids=uuids,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id',
-            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                       max_indexing_threads=0, on_disk=False),
-            **self._client_config.to_qdrant_params()
-        )
-
-        return self
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-        attributes = ['doc_id', 'dataset_id', 'document_id']
-        client = qdrant_client.QdrantClient(
-            **self._client_config.to_qdrant_params()
-        )
-
-        return QdrantVectorStore(
-            client=client,
-            collection_name=self.get_index_name(self.dataset),
-            embeddings=self._embeddings,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id'
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return QdrantVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="metadata.document_id",
-                    match=models.MatchValue(value=document_id),
-                ),
-            ],
-        ))
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key=f"metadata.{key}",
-                    match=models.MatchValue(value=value),
-                ),
-            ],
-        ))
-
-    def delete_by_ids(self, ids: list[str]) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        for node_id in ids:
-            vector_store.del_texts(models.Filter(
-                must=[
-                    models.FieldCondition(
-                        key="metadata.doc_id",
-                        match=models.MatchValue(value=node_id),
-                    ),
-                ],
-            ))
-
-    def delete_by_group_id(self, group_id: str) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=group_id),
-                ),
-            ],
-        ))
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-            ],
-        ))
-
-    def _is_origin(self):
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                return True
-
-        return False
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        return vector_store.similarity_search_by_bm25(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-                models.FieldCondition(
-                    key="page_content",
-                    match=models.MatchText(text=query),
-                )
-            ],
-        ), kwargs.get('top_k', 2))

+ 0 - 90
api/core/index/vector_index/vector_index.py

@@ -1,90 +0,0 @@
-import json
-
-from flask import current_app
-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.base import BaseVectorIndex
-from extensions.ext_database import db
-from models.dataset import Dataset, Document
-
-
-class VectorIndex:
-    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
-                 attributes: list = None):
-        if attributes is None:
-            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
-        self._dataset = dataset
-        self._embeddings = embeddings
-        self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
-        self._attributes = attributes
-
-    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
-                           attributes: list) -> BaseVectorIndex:
-        vector_type = config.get('VECTOR_STORE')
-
-        if self._dataset.index_struct_dict:
-            vector_type = self._dataset.index_struct_dict['type']
-
-        if not vector_type:
-            raise ValueError("Vector store must be specified.")
-
-        if vector_type == "weaviate":
-            from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
-
-            return WeaviateVectorIndex(
-                dataset=dataset,
-                config=WeaviateConfig(
-                    endpoint=config.get('WEAVIATE_ENDPOINT'),
-                    api_key=config.get('WEAVIATE_API_KEY'),
-                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
-                ),
-                embeddings=embeddings,
-                attributes=attributes
-            )
-        elif vector_type == "qdrant":
-            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-            return QdrantVectorIndex(
-                dataset=dataset,
-                config=QdrantConfig(
-                    endpoint=config.get('QDRANT_URL'),
-                    api_key=config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path,
-                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
-                ),
-                embeddings=embeddings
-            )
-        elif vector_type == "milvus":
-            from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
-
-            return MilvusVectorIndex(
-                dataset=dataset,
-                config=MilvusConfig(
-                    host=config.get('MILVUS_HOST'),
-                    port=config.get('MILVUS_PORT'),
-                    user=config.get('MILVUS_USER'),
-                    password=config.get('MILVUS_PASSWORD'),
-                    secure=config.get('MILVUS_SECURE'),
-                ),
-                embeddings=embeddings
-            )
-        else:
-            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
-
-    def add_texts(self, texts: list[Document], **kwargs):
-        if not self._dataset.index_struct_dict:
-            self._vector_index.create(texts, **kwargs)
-            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
-            db.session.commit()
-            return
-
-        self._vector_index.add_texts(texts, **kwargs)
-
-    def __getattr__(self, name):
-        if self._vector_index is not None:
-            method = getattr(self._vector_index, name)
-            if callable(method):
-                return method
-
-        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
-

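For reference, the class removed above was the legacy entry point for vector storage: callers handed it the dataset, the Flask config and an embeddings object, and it dispatched to the Weaviate/Qdrant/Milvus backend selected by VECTOR_STORE. A minimal sketch of that legacy call path (dataset, embeddings, documents and document_id are assumed to already exist in the caller's scope):

    from flask import current_app

    vector_index = VectorIndex(
        dataset=dataset,
        config=current_app.config,
        embeddings=embeddings
    )
    vector_index.add_texts(documents)                 # creates the collection on first use, then appends
    vector_index.delete_by_document_id(document_id)   # proxied to the backend via __getattr__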
+ 0 - 179
api/core/index/vector_index/weaviate_vector_index.py

@@ -1,179 +0,0 @@
-from typing import Any, Optional, cast
-
-import requests
-import weaviate
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel, root_validator
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from models.dataset import Dataset
-
-
-class WeaviateConfig(BaseModel):
-    endpoint: str
-    api_key: Optional[str]
-    batch_size: int = 100
-
-    @root_validator()
-    def validate_config(cls, values: dict) -> dict:
-        if not values['endpoint']:
-            raise ValueError("config WEAVIATE_ENDPOINT is required")
-        return values
-
-
-class WeaviateVectorIndex(BaseVectorIndex):
-
-    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
-        super().__init__(dataset, embeddings)
-        self._client = self._init_client(config)
-        self._attributes = attributes
-
-    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
-        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
-
-        weaviate.connect.connection.has_grpc = False
-
-        try:
-            client = weaviate.Client(
-                url=config.endpoint,
-                auth_client_secret=auth_config,
-                timeout_config=(5, 60),
-                startup_period=None
-            )
-        except requests.exceptions.ConnectionError:
-            raise ConnectionError("Vector database connection error")
-
-        client.batch.configure(
-            # `batch_size` takes an `int` value to enable auto-batching
-            # (`None` is used for manual batching)
-            batch_size=config.batch_size,
-            # dynamically update the `batch_size` based on import speed
-            dynamic=True,
-            # `timeout_retries` takes an `int` value to retry on time outs
-            timeout_retries=3,
-        )
-
-        return client
-
-    def get_type(self) -> str:
-        return 'weaviate'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-
-            return class_prefix
-
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            uuids=uuids,
-            by_text=False
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            uuids=uuids,
-            by_text=False
-        )
-
-        return self
-
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-
-        attributes = self._attributes
-        if self._is_origin():
-            attributes = ['doc_id']
-
-        return WeaviateVectorStore(
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            text_key='text',
-            embedding=self._embeddings,
-            attributes=attributes,
-            by_text=False
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return WeaviateVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.del_texts({
-            "operator": "Equal",
-            "path": ["document_id"],
-            "valueText": document_id
-        })
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.del_texts({
-            "operator": "Equal",
-            "path": [key],
-            "valueText": value
-        })
-
-    def delete_by_group_id(self, group_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def _is_origin(self):
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                return True
-
-        return False
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
-

+ 134 - 252
api/core/indexing_runner.py

@@ -9,20 +9,20 @@ from typing import Optional, cast
 
 
 from flask import Flask, current_app
 from flask_login import current_user
-from langchain.schema import Document
 from langchain.text_splitter import TextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 
-from core.data_loader.file_extractor import FileExtractor
-from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
-from core.index.index import IndexBuilder
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -31,7 +31,6 @@ from libs import helper
 from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
-from models.source import DataSourceBinding
 from services.feature_service import FeatureService
 
 
 
 
@@ -57,38 +56,19 @@ class IndexingRunner:
                 processing_rule = db.session.query(DatasetProcessRule). \
                     filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                     first()
-
-                # load file
-                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-
-                # get embedding model instance
-                embedding_model_instance = None
-                if dataset.indexing_technique == 'high_quality':
-                    if dataset.embedding_model_provider:
-                        embedding_model_instance = self.model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=dataset.embedding_model_provider,
-                            model_type=ModelType.TEXT_EMBEDDING,
-                            model=dataset.embedding_model
-                        )
-                    else:
-                        embedding_model_instance = self.model_manager.get_default_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            model_type=ModelType.TEXT_EMBEDDING,
-                        )
-
-                # get splitter
-                splitter = self._get_splitter(processing_rule, embedding_model_instance)
-
-                # split to documents
-                documents = self._step_split(
-                    text_docs=text_docs,
-                    splitter=splitter,
-                    dataset=dataset,
-                    dataset_document=dataset_document,
-                    processing_rule=processing_rule
-                )
-                self._build_index(
+                index_type = dataset_document.doc_form
+                index_processor = IndexProcessorFactory(index_type).init_index_processor()
+                # extract
+                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+
+                # transform
+                documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+                # save segment
+                self._load_segments(dataset, dataset_document, documents)
+
+                # load
+                self._load(
+                    index_processor=index_processor,
                     dataset=dataset,
                     dataset_document=dataset_document,
                     documents=documents
@@ -134,39 +114,19 @@ class IndexingRunner:
                 filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()
 
 
-            # load file
-            text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-
-            # get embedding model instance
-            embedding_model_instance = None
-            if dataset.indexing_technique == 'high_quality':
-                if dataset.embedding_model_provider:
-                    embedding_model_instance = self.model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                else:
-                    embedding_model_instance = self.model_manager.get_default_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                    )
-
-            # get splitter
-            splitter = self._get_splitter(processing_rule, embedding_model_instance)
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            # extract
+            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
 
 
-            # split to documents
-            documents = self._step_split(
-                text_docs=text_docs,
-                splitter=splitter,
-                dataset=dataset,
-                dataset_document=dataset_document,
-                processing_rule=processing_rule
-            )
+            # transform
+            documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+            # save segment
+            self._load_segments(dataset, dataset_document, documents)
 
 
-            # build index
-            self._build_index(
+            # load
+            self._load(
+                index_processor=index_processor,
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
@@ -220,7 +180,15 @@ class IndexingRunner:
                         documents.append(document)
 
             # build index
-            self._build_index(
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                first()
+
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
+            self._load(
+                index_processor=index_processor,
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
@@ -239,16 +207,16 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
 
-    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
-                               indexing_technique: str = 'economy') -> dict:
+    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
+                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                          indexing_technique: str = 'economy') -> dict:
         """
         """
         Estimate the indexing for the document.
         Estimate the indexing for the document.
         """
         """
         # check document limit
         # check document limit
         features = FeatureService.get_features(tenant_id)
         features = FeatureService.get_features(tenant_id)
         if features.billing.enabled:
         if features.billing.enabled:
-            count = len(file_details)
+            count = len(extract_settings)
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -284,16 +252,18 @@ class IndexingRunner:
         total_segments = 0
         total_price = 0
         currency = 'USD'
-        for file_detail in file_details:
-
+        index_type = doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        all_text_docs = []
+        for extract_setting in extract_settings:
+            # extract
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            all_text_docs.extend(text_docs)
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )
 
-            # load data from file
-            text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
-
             # get splitter
             splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
@@ -305,7 +275,6 @@ class IndexingRunner:
             )
 
             total_segments += len(documents)
-
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
@@ -364,154 +333,8 @@ class IndexingRunner:
             "preview": preview_texts
             "preview": preview_texts
         }
         }
 
 
-    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
-                                 indexing_technique: str = 'economy') -> dict:
-        """
-        Estimate the indexing for the document.
-        """
-        # check document limit
-        features = FeatureService.get_features(tenant_id)
-        if features.billing.enabled:
-            count = len(notion_info_list)
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
-            if count > batch_upload_limit:
-                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-        embedding_model_instance = None
-        if dataset_id:
-            dataset = Dataset.query.filter_by(
-                id=dataset_id
-            ).first()
-            if not dataset:
-                raise ValueError('Dataset not found.')
-            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                if dataset.embedding_model_provider:
-                    embedding_model_instance = self.model_manager.get_model_instance(
-                        tenant_id=tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                else:
-                    embedding_model_instance = self.model_manager.get_default_model_instance(
-                        tenant_id=tenant_id,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                    )
-        else:
-            if indexing_technique == 'high_quality':
-                embedding_model_instance = self.model_manager.get_default_model_instance(
-                    tenant_id=tenant_id,
-                    model_type=ModelType.TEXT_EMBEDDING
-                )
-        # load data from notion
-        tokens = 0
-        preview_texts = []
-        total_segments = 0
-        total_price = 0
-        currency = 'USD'
-        for notion_info in notion_info_list:
-            workspace_id = notion_info['workspace_id']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-
-            for page in notion_info['pages']:
-                loader = NotionLoader(
-                    notion_access_token=data_source_binding.access_token,
-                    notion_workspace_id=workspace_id,
-                    notion_obj_id=page['page_id'],
-                    notion_page_type=page['type']
-                )
-                documents = loader.load()
-
-                processing_rule = DatasetProcessRule(
-                    mode=tmp_processing_rule["mode"],
-                    rules=json.dumps(tmp_processing_rule["rules"])
-                )
-
-                # get splitter
-                splitter = self._get_splitter(processing_rule, embedding_model_instance)
-
-                # split to documents
-                documents = self._split_to_documents_for_estimate(
-                    text_docs=documents,
-                    splitter=splitter,
-                    processing_rule=processing_rule
-                )
-                total_segments += len(documents)
-
-                embedding_model_type_instance = None
-                if embedding_model_instance:
-                    embedding_model_type_instance = embedding_model_instance.model_type_instance
-                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-
-                for document in documents:
-                    if len(preview_texts) < 5:
-                        preview_texts.append(document.page_content)
-                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
-                        tokens += embedding_model_type_instance.get_num_tokens(
-                            model=embedding_model_instance.model,
-                            credentials=embedding_model_instance.credentials,
-                            texts=[document.page_content]
-                        )
-
-        if doc_form and doc_form == 'qa_model':
-            model_instance = self.model_manager.get_default_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM
-            )
-
-            model_type_instance = model_instance.model_type_instance
-            model_type_instance = cast(LargeLanguageModel, model_type_instance)
-            if len(preview_texts) > 0:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
-                                                             doc_language)
-                document_qa_list = self.format_split_text(response)
-
-                price_info = model_type_instance.get_price(
-                    model=model_instance.model,
-                    credentials=model_instance.credentials,
-                    price_type=PriceType.INPUT,
-                    tokens=total_segments * 2000,
-                )
-
-                return {
-                    "total_segments": total_segments * 20,
-                    "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(price_info.total_amount),
-                    "currency": price_info.currency,
-                    "qa_preview": document_qa_list,
-                    "preview": preview_texts
-                }
-        if embedding_model_instance:
-            embedding_model_type_instance = embedding_model_instance.model_type_instance
-            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-            embedding_price_info = embedding_model_type_instance.get_price(
-                model=embedding_model_instance.model,
-                credentials=embedding_model_instance.credentials,
-                price_type=PriceType.INPUT,
-                tokens=tokens
-            )
-            total_price = '{:f}'.format(embedding_price_info.total_amount)
-            currency = embedding_price_info.currency
-        return {
-            "total_segments": total_segments,
-            "tokens": tokens,
-            "total_price": total_price,
-            "currency": currency,
-            "preview": preview_texts
-        }
-
-    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
+    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
+            -> list[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -527,11 +350,27 @@ class IndexingRunner:
                 one_or_none()
 
             if file_detail:
-                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=dataset_document.doc_form
+                )
+                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         elif dataset_document.data_source_type == 'notion_import':
-            loader = NotionLoader.from_document(dataset_document)
-            text_docs = loader.load()
-
+            if (not data_source_info or 'notion_workspace_id' not in data_source_info
+                    or 'notion_page_id' not in data_source_info):
+                raise ValueError("no notion import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="notion_import",
+                notion_info={
+                    "notion_workspace_id": data_source_info['notion_workspace_id'],
+                    "notion_obj_id": data_source_info['notion_page_id'],
+                    "notion_page_type": data_source_info['notion_page_type'],
+                    "document": dataset_document
+                },
+                document_model=dataset_document.doc_form
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         # update document status to splitting
         self._update_document_index_status(
             document_id=dataset_document.id,
@@ -545,8 +384,6 @@ class IndexingRunner:
         # replace doc id to document model id
         text_docs = cast(list[Document], text_docs)
         for text_doc in text_docs:
-            # remove invalid symbol
-            text_doc.page_content = self.filter_string(text_doc.page_content)
             text_doc.metadata['document_id'] = dataset_document.id
             text_doc.metadata['dataset_id'] = dataset_document.dataset_id
 
@@ -787,12 +624,12 @@ class IndexingRunner:
             for q, a in matches if q and a
         ]
 
-    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
+    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+              dataset_document: DatasetDocument, documents: list[Document]) -> None:
         """
         """
-        Build the index for the document.
+        insert index and update document/segment status to completed
         """
         """
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
+
         embedding_model_instance = None
         if dataset.indexing_technique == 'high_quality':
             embedding_model_instance = self.model_manager.get_model_instance(
@@ -825,13 +662,8 @@ class IndexingRunner:
                     )
                     for document in chunk_documents
                 )
-
-            # save vector index
-            if vector_index:
-                vector_index.add_texts(chunk_documents)
-
-            # save keyword index
-            keyword_table_index.add_texts(chunk_documents)
+            # load index
+            index_processor.load(dataset, chunk_documents)
 
 
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
@@ -911,14 +743,64 @@ class IndexingRunner:
             )
             documents.append(document)
         # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents, duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts(documents)
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.load(dataset, documents)
+
+    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+                   text_docs: list[Document], process_rule: dict) -> list[Document]:
+        # get embedding model instance
+        embedding_model_instance = None
+        if dataset.indexing_technique == 'high_quality':
+            if dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            else:
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
+        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
+                                              process_rule=process_rule)
+
+        return documents
+
+    def _load_segments(self, dataset, dataset_document, documents):
+        # save node to document segment
+        doc_store = DatasetDocumentStore(
+            dataset=dataset,
+            user_id=dataset_document.created_by,
+            document_id=dataset_document.id
+        )
+
+        # add document segments
+        doc_store.add_documents(documents)
+
+        # update document status to indexing
+        cur_time = datetime.datetime.utcnow()
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="indexing",
+            extra_update_params={
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
+            }
+        )
+
+        # update segment status to indexing
+        self._update_segments_by_document(
+            dataset_document_id=dataset_document.id,
+            update_params={
+                DocumentSegment.status: "indexing",
+                DocumentSegment.indexing_at: datetime.datetime.utcnow()
+            }
+        )
+        pass
 
 
 
 
 class DocumentIsPausedException(Exception):

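The reworked run path above reduces to four steps: extract, transform, save segments, load. A minimal sketch of that flow, assuming a DatasetDocument and its DatasetProcessRule have already been fetched; names follow the diff, and the snippet is illustrative rather than a drop-in replacement for the runner's own loop:

    runner = IndexingRunner()
    index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

    # extract raw documents from the bound data source (upload_file or notion_import)
    text_docs = runner._extract(index_processor, dataset_document, processing_rule.to_dict())

    # clean and split into chunk documents; the embedding model is resolved from the dataset settings
    documents = runner._transform(index_processor, dataset, text_docs, processing_rule.to_dict())

    # persist chunks as DocumentSegment rows and flip the document status to "indexing"
    runner._load_segments(dataset, dataset_document, documents)

    # write the keyword/vector indexes and mark document and segments as completed
    runner._load(index_processor=index_processor, dataset=dataset,
                 dataset_document=dataset_document, documents=documents)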
+ 0 - 0
api/core/rag/__init__.py


+ 38 - 0
api/core/rag/cleaner/clean_processor.py

@@ -0,0 +1,38 @@
+import re
+
+
+class CleanProcessor:
+
+    @classmethod
+    def clean(cls, text: str, process_rule: dict) -> str:
+        # default clean
+        # remove invalid symbol
+        text = re.sub(r'<\|', '<', text)
+        text = re.sub(r'\|>', '>', text)
+        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
+        # Unicode  U+FFFE
+        text = re.sub('\uFFFE', '', text)
+
+        rules = process_rule.get('rules', {}) if process_rule else {}
+        if 'pre_processing_rules' in rules:
+            pre_processing_rules = rules["pre_processing_rules"]
+            for pre_processing_rule in pre_processing_rules:
+                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
+                    # Remove extra spaces
+                    pattern = r'\n{3,}'
+                    text = re.sub(pattern, '\n\n', text)
+                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
+                    text = re.sub(pattern, ' ', text)
+                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
+                    # Remove email
+                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
+                    text = re.sub(pattern, '', text)
+
+                    # Remove URL
+                    pattern = r'https?://[^\s]+'
+                    text = re.sub(pattern, '', text)
+        return text
+
+    def filter_string(self, text):
+
+        return text

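The process_rule dict consumed by CleanProcessor.clean mirrors the dataset processing rules: a "rules" key holding "pre_processing_rules" entries with an id and an enabled flag. A small usage sketch with made-up input text:

    from core.rag.cleaner.clean_processor import CleanProcessor

    process_rule = {
        "rules": {
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
                {"id": "remove_urls_emails", "enabled": True},
            ]
        }
    }
    raw = "Contact   me at someone@example.com   via https://example.com\n\n\n\nThanks"
    cleaned = CleanProcessor.clean(raw, process_rule)
    # runs of spaces and blank lines are collapsed, and the email address and URL are stripped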
+ 12 - 0
api/core/rag/cleaner/cleaner_base.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document cleaner implementations."""
+from abc import ABC, abstractmethod
+
+
+class BaseCleaner(ABC):
+    """Interface for clean chunk content.
+    """
+
+    @abstractmethod
+    def clean(self, content: str):
+        raise NotImplementedError
+

+ 12 - 0
api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredExtraWhitespaceCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.core import clean_extra_whitespace
+
+        # Returns "ITEM 1A: RISK FACTORS"
+        return clean_extra_whitespace(content)

+ 15 - 0
api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py

@@ -0,0 +1,15 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        import re
+
+        from unstructured.cleaners.core import group_broken_paragraphs
+
+        para_split_re = re.compile(r"(\s*\n\s*){3}")
+
+        return group_broken_paragraphs(content, paragraph_split=para_split_re)

+ 12 - 0
api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.core import clean_non_ascii_chars
+
+        # Returns "This text containsnon-ascii characters!"
+        return clean_non_ascii_chars(content)

+ 11 - 0
api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py

@@ -0,0 +1,11 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredReplaceUnicodeQuotesCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """Replaces unicode quote characters, such as the \x91 character in a string."""
+
+        from unstructured.cleaners.core import replace_unicode_quotes
+        return replace_unicode_quotes(content)

+ 11 - 0
api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py

@@ -0,0 +1,11 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredTranslateTextCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.translate import translate_text
+
+        return translate_text(content)

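Each cleaner above wraps a single helper from the unstructured package behind the BaseCleaner interface, so several of them can be chained over the same content. A sketch of composing two of them, assuming the unstructured package is installed and each module exports a class named after its file (UnstructuredReplaceUnicodeQuotesCleaner, UnstructuredExtraWhitespaceCleaner):

    from core.rag.cleaner.unstructured.unstructured_extra_whitespace_cleaner import UnstructuredExtraWhitespaceCleaner
    from core.rag.cleaner.unstructured.unstructured_replace_unicode_quotes_cleaner import (
        UnstructuredReplaceUnicodeQuotesCleaner,
    )

    cleaners = [UnstructuredReplaceUnicodeQuotesCleaner(), UnstructuredExtraWhitespaceCleaner()]
    content = "\x91ITEM 1A:     RISK FACTORS\x92"
    for cleaner in cleaners:
        content = cleaner.clean(content)
    # the quote characters are normalised and the run of spaces collapses,
    # leaving "ITEM 1A: RISK FACTORS" between the normalised quotes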
+ 0 - 0
api/core/rag/data_post_processor/__init__.py


+ 49 - 0
api/core/rag/data_post_processor/data_post_processor.py

@@ -0,0 +1,49 @@
+from typing import Optional
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.data_post_processor.reorder import ReorderRunner
+from core.rag.models.document import Document
+from core.rerank.rerank import RerankRunner
+
+
+class DataPostProcessor:
+    """Interface for data post-processing document.
+    """
+
+    def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
+        self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
+        self.reorder_runner = self._get_reorder_runner(reorder_enabled)
+
+    def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
+               top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
+        if self.rerank_runner:
+            documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
+
+        if self.reorder_runner:
+            documents = self.reorder_runner.run(documents)
+
+        return documents
+
+    def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
+        if reranking_model:
+            try:
+                model_manager = ModelManager()
+                rerank_model_instance = model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=reranking_model['reranking_provider_name'],
+                    model_type=ModelType.RERANK,
+                    model=reranking_model['reranking_model_name']
+                )
+            except InvokeAuthorizationError:
+                return None
+            return RerankRunner(rerank_model_instance)
+        return None
+
+    def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
+        if reorder_enabled:
+            return ReorderRunner()
+        return None
+
+

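DataPostProcessor is the retrieval post-processing entry point: an optional rerank pass (silently skipped when the rerank model cannot be authorised) followed by the optional reorder pass. A usage sketch; the provider and model names are placeholders, and tenant_id, query and candidates are assumed to exist in the caller:

    from core.rag.data_post_processor.data_post_processor import DataPostProcessor

    post_processor = DataPostProcessor(
        tenant_id=tenant_id,
        reranking_model={
            'reranking_provider_name': 'cohere',           # placeholder provider
            'reranking_model_name': 'rerank-english-v2.0'  # placeholder model
        },
        reorder_enabled=True
    )
    documents = post_processor.invoke(
        query=query,
        documents=candidates,   # list[Document] returned by the retriever
        score_threshold=0.5,
        top_n=4
    )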
+ 19 - 0
api/core/rag/data_post_processor/reorder.py

@@ -0,0 +1,19 @@
+
+from langchain.schema import Document
+
+
+class ReorderRunner:
+
+    def run(self, documents: list[Document]) -> list[Document]:
+        # Take the elements at positions 0, 2, 4, ... (the 1st, 3rd, 5th documents)
+        odd_elements = documents[::2]
+
+        # Retrieve elements from even indices (1, 3, 5, etc.) of the documents list
+        even_elements = documents[1::2]
+
+        # Reverse the list of elements from even indices
+        even_elements_reversed = even_elements[::-1]
+
+        new_documents = odd_elements + even_elements_reversed
+
+        return new_documents
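
A quick check of the reordering logic above, with plain strings standing in for Document objects. Since retrieved documents typically arrive ranked best-first, this interleave-and-reverse places the strongest matches at both ends of the list:

    docs = ['d1', 'd2', 'd3', 'd4', 'd5']   # ranked best-first
    reordered = docs[::2] + docs[1::2][::-1]
    print(reordered)   # ['d1', 'd3', 'd5', 'd4', 'd2']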

+ 0 - 0
api/core/rag/datasource/__init__.py


+ 21 - 0
api/core/rag/datasource/entity/embedding.py

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class Embeddings(ABC):
+    """Interface for embedding models."""
+
+    @abstractmethod
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        """Embed search docs."""
+
+    @abstractmethod
+    def embed_query(self, text: str) -> list[float]:
+        """Embed query text."""
+
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+        """Asynchronous Embed search docs."""
+        raise NotImplementedError
+
+    async def aembed_query(self, text: str) -> list[float]:
+        """Asynchronous Embed query text."""
+        raise NotImplementedError
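
In this PR the concrete implementation of this interface is CacheEmbedding (see vector_factory.py below); the sketch here is only a toy subclass to illustrate the contract, with made-up fixed-size vectors:

    from core.rag.datasource.entity.embedding import Embeddings


    class ToyEmbeddings(Embeddings):
        """Deterministic pseudo-embeddings, useful only for tests."""

        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            return [self.embed_query(text) for text in texts]

        def embed_query(self, text: str) -> list[float]:
            # 8-dimensional vector derived from character codes (illustrative only)
            return [float(ord(char) % 10) for char in text[:8].ljust(8)]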

+ 0 - 0
api/core/rag/datasource/keyword/__init__.py


+ 0 - 0
api/core/rag/datasource/keyword/jieba/__init__.py


+ 18 - 90
api/core/index/keyword_table_index/keyword_table_index.py → api/core/rag/datasource/keyword/jieba/jieba.py

@@ -2,11 +2,11 @@ import json
 from collections import defaultdict
 from typing import Any, Optional
 
-from langchain.schema import BaseRetriever, Document
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel
 
-from core.index.base import BaseIndex
-from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
 
@@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
     max_keywords_per_chunk: int = 10
 
 
-class KeywordTableIndex(BaseIndex):
-    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+class Jieba(BaseKeyword):
+    def __init__(self, dataset: Dataset):
         super().__init__(dataset)
-        self._config = config
+        self._config = KeywordTableConfig()
 
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
         keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
-        self._save_dataset_keyword_table(keyword_table)
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
+        keyword_table = self._get_dataset_keyword_table()
         for text in texts:
             keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
             self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
 
-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
         self._save_dataset_keyword_table(keyword_table)
 
         return self
@@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
         keyword_table_handler = JiebaKeywordTableHandler()
 
         keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+        keywords_list = kwargs.get('keywords_list', None)
+        for i in range(len(texts)):
+            text = texts[i]
+            if keywords_list:
+                keywords = keywords_list[i]
+            else:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
             self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
 
@@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
 
         self._save_dataset_keyword_table(keyword_table)
 
-    def delete_by_metadata_field(self, key: str, value: str):
-        pass
-
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-        return KeywordTableRetriever(index=self, **kwargs)
-
     def search(
             self, query: str,
             **kwargs: Any
     ) -> list[Document]:
         keyword_table = self._get_dataset_keyword_table()
 
-        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-        k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+        k = kwargs.get('top_k', 4)
 
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
 
@@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
             db.session.delete(dataset_keyword_table)
             db.session.commit()
 
-    def delete_by_group_id(self, group_id: str) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
             '__type__': 'keyword_table',
@@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
         ).first()
         if document_segment:
             document_segment.keywords = keywords
+            db.session.add(document_segment)
             db.session.commit()
 
     def create_segment_keywords(self, node_id: str, keywords: list[str]):
@@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
         self._save_dataset_keyword_table(keyword_table)
 
 
-class KeywordTableRetriever(BaseRetriever, BaseModel):
-    index: KeywordTableIndex
-    search_kwargs: dict = Field(default_factory=dict)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    def get_relevant_documents(self, query: str) -> list[Document]:
-        """Get documents relevant for a query.
-
-        Args:
-            query: string to find relevant documents for
-
-        Returns:
-            List of relevant documents
-        """
-        return self.index.search(query, **self.search_kwargs)
-
-    async def aget_relevant_documents(self, query: str) -> list[Document]:
-        raise NotImplementedError("KeywordTableRetriever does not support async")
-
-
 class SetEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, set):

+ 1 - 1
api/core/index/keyword_table_index/jieba_keyword_table_handler.py → api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -3,7 +3,7 @@ import re
 import jieba
 from jieba.analyse import default_tfidf
 
-from core.index.keyword_table_index.stopwords import STOPWORDS
+from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 
 
 class JiebaKeywordTableHandler:

+ 0 - 0
api/core/index/keyword_table_index/stopwords.py → api/core/rag/datasource/keyword/jieba/stopwords.py


+ 5 - 23
api/core/index/base.py → api/core/rag/datasource/keyword/keyword_base.py

@@ -3,22 +3,17 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any
 
-from langchain.schema import BaseRetriever, Document
-
+from core.rag.models.document import Document
 from models.dataset import Dataset
 
 
-class BaseIndex(ABC):
+class BaseKeyword(ABC):
 
     def __init__(self, dataset: Dataset):
         self.dataset = dataset
 
     @abstractmethod
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        raise NotImplementedError
-
-    @abstractmethod
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
         raise NotImplementedError
 
     @abstractmethod
@@ -34,31 +29,18 @@ class BaseIndex(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def delete_by_metadata_field(self, key: str, value: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_group_id(self, group_id: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_document_id(self, document_id: str):
+    def delete_by_document_id(self, document_id: str) -> None:
         raise NotImplementedError
 
-    @abstractmethod
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+    def delete(self) -> None:
         raise NotImplementedError
 
-    @abstractmethod
     def search(
             self, query: str,
             **kwargs: Any
     ) -> list[Document]:
         raise NotImplementedError
 
-    def delete(self) -> None:
-        raise NotImplementedError
-
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts:
             doc_id = text.metadata['doc_id']

+ 60 - 0
api/core/rag/datasource/keyword/keyword_factory.py

@@ -0,0 +1,60 @@
+from typing import Any, cast
+
+from flask import current_app
+
+from core.rag.datasource.keyword.jieba.jieba import Jieba
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
+from models.dataset import Dataset
+
+
+class Keyword:
+    def __init__(self, dataset: Dataset):
+        self._dataset = dataset
+        self._keyword_processor = self._init_keyword()
+
+    def _init_keyword(self) -> BaseKeyword:
+        config = cast(dict, current_app.config)
+        keyword_type = config.get('KEYWORD_STORE')
+
+        if not keyword_type:
+            raise ValueError("Keyword store must be specified.")
+
+        if keyword_type == "jieba":
+            return Jieba(
+                dataset=self._dataset
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def create(self, texts: list[Document], **kwargs):
+        self._keyword_processor.create(texts, **kwargs)
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        self._keyword_processor.add_texts(texts, **kwargs)
+
+    def text_exists(self, id: str) -> bool:
+        return self._keyword_processor.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._keyword_processor.delete_by_ids(ids)
+
+    def delete_by_document_id(self, document_id: str) -> None:
+        self._keyword_processor.delete_by_document_id(document_id)
+
+    def delete(self) -> None:
+        self._keyword_processor.delete()
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        return self._keyword_processor.search(query, **kwargs)
+
+    def __getattr__(self, name):
+        if self._keyword_processor is not None:
+            method = getattr(self._keyword_processor, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'Keyword' object has no attribute '{name}'")

+ 165 - 0
api/core/rag/datasource/retrieval_service.py

@@ -0,0 +1,165 @@
+import threading
+from typing import Optional
+
+from flask import Flask, current_app
+from flask_login import current_user
+
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
+from extensions.ext_database import db
+from models.dataset import Dataset
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}
+
+
+class RetrievalService:
+
+    @classmethod
+    def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
+                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+        all_documents = []
+        threads = []
+        # retrieval_model source with keyword
+        if retrival_method == 'keyword_search':
+            keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'top_k': top_k,
+                'all_documents': all_documents
+            })
+            threads.append(keyword_thread)
+            keyword_thread.start()
+        # retrieval_model source with semantic
+        if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
+            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'top_k': top_k,
+                'score_threshold': score_threshold,
+                'reranking_model': reranking_model,
+                'all_documents': all_documents,
+                'retrival_method': retrival_method
+            })
+            threads.append(embedding_thread)
+            embedding_thread.start()
+
+        # retrieval source with full text
+        if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
+            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'retrival_method': retrival_method,
+                'score_threshold': score_threshold,
+                'top_k': top_k,
+                'reranking_model': reranking_model,
+                'all_documents': all_documents
+            })
+            threads.append(full_text_index_thread)
+            full_text_index_thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        if retrival_method == 'hybrid_search':
+            data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
+            all_documents = data_post_processor.invoke(
+                query=query,
+                documents=all_documents,
+                score_threshold=score_threshold,
+                top_n=top_k
+            )
+        return all_documents
+
+    @classmethod
+    def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                       top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            keyword = Keyword(
+                dataset=dataset
+            )
+
+            documents = keyword.search(
+                query,
+                top_k=top_k
+            )
+            all_documents.extend(documents)
+
+    @classmethod
+    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                         all_documents: list, retrival_method: str):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            vector = Vector(
+                dataset=dataset
+            )
+
+            documents = vector.search_by_vector(
+                query,
+                search_type='similarity_score_threshold',
+                top_k=top_k,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
+                }
+            )
+
+            if documents:
+                if reranking_model and retrival_method == 'semantic_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
+                        query=query,
+                        documents=documents,
+                        score_threshold=score_threshold,
+                        top_n=len(documents)
+                    ))
+                else:
+                    all_documents.extend(documents)
+
+    @classmethod
+    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                               all_documents: list, retrival_method: str):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            vector_processor = Vector(
+                dataset=dataset,
+            )
+
+            documents = vector_processor.search_by_full_text(
+                query,
+                top_k=top_k
+            )
+            if documents:
+                if reranking_model and retrival_method == 'full_text_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
+                        query=query,
+                        documents=documents,
+                        score_threshold=score_threshold,
+                        top_n=len(documents)
+                    ))
+                else:
+                    all_documents.extend(documents)
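
A minimal sketch of calling the service for hybrid retrieval. It must run inside a Flask request context (hybrid_search reranks under current_user's tenant), and the dataset id and reranking model names are illustrative:

    from core.rag.datasource.retrieval_service import RetrievalService

    documents = RetrievalService.retrieve(
        retrival_method='hybrid_search',   # parameter name as defined above
        dataset_id='dataset-id',           # illustrative
        query='how does hybrid search work?',
        top_k=4,
        score_threshold=0.0,
        reranking_model={
            'reranking_provider_name': 'cohere',                 # illustrative
            'reranking_model_name': 'rerank-multilingual-v2.0'   # illustrative
        }
    )
    for document in documents:
        print(document.metadata.get('score'), document.page_content[:80])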

+ 0 - 0
api/core/rag/datasource/vdb/__init__.py


+ 10 - 0
api/core/rag/datasource/vdb/field.py

@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class Field(Enum):
+    CONTENT_KEY = "page_content"
+    METADATA_KEY = "metadata"
+    GROUP_KEY = "group_id"
+    VECTOR = "vector"
+    TEXT_KEY = "text"
+    PRIMARY_KEY = "id"
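
For reference, these keys shape the per-record payload that the vector store implementations below read and write; the insert dict built in MilvusVector.add_texts, for example, looks roughly like this (values illustrative):

    from core.rag.datasource.vdb.field import Field

    record = {
        Field.CONTENT_KEY.value: "chunk text",                                   # "page_content"
        Field.VECTOR.value: [0.12, 0.34, 0.56],                                  # "vector"
        Field.METADATA_KEY.value: {"doc_id": "node-1", "document_id": "doc-1"},  # "metadata"
    }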

+ 0 - 0
api/core/rag/datasource/vdb/milvus/__init__.py


+ 214 - 0
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -0,0 +1,214 @@
+import logging
+from typing import Any, Optional
+from uuid import uuid4
+
+from pydantic import BaseModel, root_validator
+from pymilvus import MilvusClient, MilvusException, connections
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MilvusConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    secure: bool = False
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config MILVUS_HOST is required")
+        if not values['port']:
+            raise ValueError("config MILVUS_PORT is required")
+        if not values['user']:
+            raise ValueError("config MILVUS_USER is required")
+        if not values['password']:
+            raise ValueError("config MILVUS_PASSWORD is required")
+        return values
+
+    def to_milvus_params(self):
+        return {
+            'host': self.host,
+            'port': self.port,
+            'user': self.user,
+            'password': self.password,
+            'secure': self.secure
+        }
+
+
+class MilvusVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: MilvusConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = self._init_client(config)
+        self._consistency_level = 'Session'
+        self._fields = []
+
+    def get_type(self) -> str:
+        return 'milvus'
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        index_params = {
+            'metric_type': 'IP',
+            'index_type': "HNSW",
+            'params': {"M": 8, "efConstruction": 64}
+        }
+        metadatas = [d.metadata for d in texts]
+
+        # Grab the existing collection if it exists
+        from pymilvus import utility
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
+        if not utility.has_collection(self._collection_name, using=alias):
+            self.create_collection(embeddings, metadatas, index_params)
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        insert_dict_list = []
+        for i in range(len(documents)):
+            insert_dict = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],
+                Field.METADATA_KEY.value: documents[i].metadata
+            }
+            insert_dict_list.append(insert_dict)
+        # Total insert count
+        total_count = len(insert_dict_list)
+
+        pks: list[str] = []
+
+        for i in range(0, total_count, 1000):
+            batch_insert_list = insert_dict_list[i:i + 1000]
+            # Insert into the collection.
+            try:
+                ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
+                pks.extend(ids)
+            except MilvusException as e:
+                logger.error(
+                    "Failed to insert batch starting at entity: %s/%s", i, total_count
+                )
+                raise e
+        return pks
+
+    def delete_by_document_id(self, document_id: str):
+
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            self._client.delete(collection_name=self._collection_name, pks=ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        result = self._client.query(collection_name=self._collection_name,
+                                    filter=f'metadata["{key}"] == "{value}"',
+                                    output_fields=["id"])
+        if result:
+            return [item["id"] for item in result]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            self._client.delete(collection_name=self._collection_name, pks=ids)
+
+    def delete_by_ids(self, doc_ids: list[str]) -> None:
+
+        self._client.delete(collection_name=self._collection_name, pks=doc_ids)
+
+    def delete(self) -> None:
+
+        from pymilvus import utility
+        utility.drop_collection(self._collection_name, None)
+
+    def text_exists(self, id: str) -> bool:
+
+        result = self._client.query(collection_name=self._collection_name,
+                                    filter=f'metadata["doc_id"] == "{id}"',
+                                    output_fields=["id"])
+
+        return len(result) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+
+        # Set search parameters.
+        results = self._client.search(collection_name=self._collection_name,
+                                      data=[query_vector],
+                                      limit=kwargs.get('top_k', 4),
+                                      output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+                                      )
+        # Organize results.
+        docs = []
+        for result in results[0]:
+            metadata = result['entity'].get(Field.METADATA_KEY.value)
+            metadata['score'] = result['distance']
+            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
+            if result['distance'] > score_threshold:
+                doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
+                               metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # milvus/zilliz doesn't support bm25 search
+        return []
+
+    def create_collection(
+            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ) -> str:
+        from pymilvus import CollectionSchema, DataType, FieldSchema
+        from pymilvus.orm.types import infer_dtype_bydata
+
+        # Determine embedding dim
+        dim = len(embeddings[0])
+        fields = []
+        if metadatas:
+            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+
+        # Create the text field
+        fields.append(
+            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
+        )
+        # Create the primary key field
+        fields.append(
+            FieldSchema(
+                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
+            )
+        )
+        # Create the vector field, supports binary or float vectors
+        fields.append(
+            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
+        )
+
+        # Create the schema for the collection
+        schema = CollectionSchema(fields)
+
+        for x in schema.fields:
+            self._fields.append(x.name)
+        # Since primary field is auto-id, no need to track it
+        self._fields.remove(Field.PRIMARY_KEY.value)
+
+        # Create the collection
+        collection_name = self._collection_name
+        self._client.create_collection_with_schema(collection_name=collection_name,
+                                                   schema=schema, index_param=index_params,
+                                                   consistency_level=self._consistency_level)
+        return collection_name
+
+    def _init_client(self, config) -> MilvusClient:
+        if config.secure:
+            uri = "https://" + str(config.host) + ":" + str(config.port)
+        else:
+            uri = "http://" + str(config.host) + ":" + str(config.port)
+        client = MilvusClient(uri=uri, user=config.user, password=config.password)
+        return client
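
A minimal construction sketch for the Milvus backend, assuming a reachable Milvus instance; host, port and credentials below are illustrative (in vector_factory.py they come from the MILVUS_* settings):

    from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector

    vector = MilvusVector(
        collection_name='Vector_index_example_Node',   # illustrative
        config=MilvusConfig(
            host='127.0.0.1',
            port=19530,
            user='root',          # illustrative credentials
            password='Milvus',
            secure=False
        )
    )
    # Query an existing collection; the 768-dim query vector is a placeholder
    docs = vector.search_by_vector([0.0] * 768, top_k=4, score_threshold=0.5)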

+ 0 - 0
api/core/rag/datasource/vdb/qdrant/__init__.py


+ 360 - 0
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -0,0 +1,360 @@
+import os
+import uuid
+from collections.abc import Generator, Iterable, Sequence
+from itertools import islice
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
+
+import qdrant_client
+from pydantic import BaseModel
+from qdrant_client.http import models as rest
+from qdrant_client.http.models import (
+    FilterSelector,
+    HnswConfigDiff,
+    PayloadSchemaType,
+    TextIndexParams,
+    TextIndexType,
+    TokenizerType,
+)
+from qdrant_client.local.qdrant_local import QdrantLocal
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+
+if TYPE_CHECKING:
+    from qdrant_client import grpc  # noqa
+    from qdrant_client.conversions import common_types
+    from qdrant_client.http import models as rest
+
+    DictFilter = dict[str, Union[str, int, bool, dict, list]]
+    MetadataFilter = Union[DictFilter, common_types.Filter]
+
+
+class QdrantConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    timeout: float = 20
+    root_path: Optional[str]
+
+    def to_qdrant_params(self):
+        if self.endpoint and self.endpoint.startswith('path:'):
+            path = self.endpoint.replace('path:', '')
+            if not os.path.isabs(path):
+                path = os.path.join(self.root_path, path)
+
+            return {
+                'path': path
+            }
+        else:
+            return {
+                'url': self.endpoint,
+                'api_key': self.api_key,
+                'timeout': self.timeout
+            }
+
+
+class QdrantVector(BaseVector):
+
+    def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
+        self._distance_func = distance_func.upper()
+        self._group_id = group_id
+
+    def get_type(self) -> str:
+        return 'qdrant'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self._collection_name}
+        }
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        if texts:
+            # get embedding vector size
+            vector_size = len(embeddings[0])
+            # get collection name
+            collection_name = self._collection_name
+            collection_name = collection_name or uuid.uuid4().hex
+            all_collection_name = []
+            collections_response = self._client.get_collections()
+            collection_list = collections_response.collections
+            for collection in collection_list:
+                all_collection_name.append(collection.name)
+            if collection_name not in all_collection_name:
+                # create collection
+                self.create_collection(collection_name, vector_size)
+
+            self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str, vector_size: int):
+        from qdrant_client.http import models as rest
+        vectors_config = rest.VectorParams(
+            size=vector_size,
+            distance=rest.Distance[self._distance_func],
+        )
+        hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                     max_indexing_threads=0, on_disk=False)
+        self._client.recreate_collection(
+            collection_name=collection_name,
+            vectors_config=vectors_config,
+            hnsw_config=hnsw_config,
+            timeout=int(self._client_config.timeout),
+        )
+
+        # create payload index
+        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
+                                          field_schema=PayloadSchemaType.KEYWORD,
+                                          field_type=PayloadSchemaType.KEYWORD)
+        # create full-text index
+        text_index_params = TextIndexParams(
+            type=TextIndexType.TEXT,
+            tokenizer=TokenizerType.MULTILINGUAL,
+            min_token_len=2,
+            max_token_len=20,
+            lowercase=True
+        )
+        self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
+                                          field_schema=text_index_params)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        added_ids = []
+        for batch_ids, points in self._generate_rest_batches(
+                texts, embeddings, metadatas, uuids, 64, self._group_id
+        ):
+            self._client.upsert(
+                collection_name=self._collection_name, points=points
+            )
+            added_ids.extend(batch_ids)
+
+        return added_ids
+
+    def _generate_rest_batches(
+            self,
+            texts: Iterable[str],
+            embeddings: list[list[float]],
+            metadatas: Optional[list[dict]] = None,
+            ids: Optional[Sequence[str]] = None,
+            batch_size: int = 64,
+            group_id: Optional[str] = None,
+    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
+        from qdrant_client.http import models as rest
+        texts_iterator = iter(texts)
+        embeddings_iterator = iter(embeddings)
+        metadatas_iterator = iter(metadatas or [])
+        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+        while batch_texts := list(islice(texts_iterator, batch_size)):
+            # Take the corresponding metadata and id for each text in a batch
+            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+            batch_ids = list(islice(ids_iterator, batch_size))
+
+            # Generate the embeddings for all the texts in a batch
+            batch_embeddings = list(islice(embeddings_iterator, batch_size))
+
+            points = [
+                rest.PointStruct(
+                    id=point_id,
+                    vector=vector,
+                    payload=payload,
+                )
+                for point_id, vector, payload in zip(
+                    batch_ids,
+                    batch_embeddings,
+                    self._build_payloads(
+                        batch_texts,
+                        batch_metadatas,
+                        Field.CONTENT_KEY.value,
+                        Field.METADATA_KEY.value,
+                        group_id,
+                        Field.GROUP_KEY.value,
+                    ),
+                )
+            ]
+
+            yield batch_ids, points
+
+    @classmethod
+    def _build_payloads(
+            cls,
+            texts: Iterable[str],
+            metadatas: Optional[list[dict]],
+            content_payload_key: str,
+            metadata_payload_key: str,
+            group_id: str,
+            group_payload_key: str
+    ) -> list[dict]:
+        payloads = []
+        for i, text in enumerate(texts):
+            if text is None:
+                raise ValueError(
+                    "At least one of the texts is None. Please remove it before "
+                    "calling .from_texts or .add_texts on Qdrant instance."
+                )
+            metadata = metadatas[i] if metadatas is not None else None
+            payloads.append(
+                {
+                    content_payload_key: text,
+                    metadata_payload_key: metadata,
+                    group_payload_key: group_id
+                }
+            )
+
+        return payloads
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        from qdrant_client.http import models
+
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key=f"metadata.{key}",
+                    match=models.MatchValue(value=value),
+                ),
+            ],
+        )
+
+        self._reload_if_needed()
+
+        self._client.delete(
+            collection_name=self._collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def delete(self):
+        from qdrant_client.http import models
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+            ],
+        )
+        self._client.delete(
+            collection_name=self._collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+
+        from qdrant_client.http import models
+        for node_id in ids:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchValue(value=node_id),
+                    ),
+                ],
+            )
+            self._client.delete(
+                collection_name=self._collection_name,
+                points_selector=FilterSelector(
+                    filter=filter
+                ),
+            )
+
+    def text_exists(self, id: str) -> bool:
+        response = self._client.retrieve(
+            collection_name=self._collection_name,
+            ids=[id]
+        )
+
+        return len(response) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        from qdrant_client.http import models
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+            ],
+        )
+        results = self._client.search(
+            collection_name=self._collection_name,
+            query_vector=query_vector,
+            query_filter=filter,
+            limit=kwargs.get("top_k", 4),
+            with_payload=True,
+            with_vectors=True,
+            score_threshold=kwargs.get("score_threshold", .0)
+        )
+        docs = []
+        for result in results:
+            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
+            # duplicate check score threshold
+            score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+            if result.score > score_threshold:
+                metadata['score'] = result.score
+                doc = Document(
+                    page_content=result.payload.get(Field.CONTENT_KEY.value),
+                    metadata=metadata,
+                )
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """Return docs most similar by bm25.
+        Returns:
+            List of documents most similar to the query text and distance for each.
+        """
+        from qdrant_client.http import models
+        scroll_filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+                models.FieldCondition(
+                    key="page_content",
+                    match=models.MatchText(text=query),
+                )
+            ]
+        )
+        response = self._client.scroll(
+            collection_name=self._collection_name,
+            scroll_filter=scroll_filter,
+            limit=kwargs.get('top_k', 2),
+            with_payload=True,
+            with_vectors=True
+
+        )
+        results = response[0]
+        documents = []
+        for result in results:
+            if result:
+                documents.append(self._document_from_scored_point(
+                    result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
+                ))
+
+        return documents
+
+    def _reload_if_needed(self):
+        if isinstance(self._client, QdrantLocal):
+            self._client = cast(QdrantLocal, self._client)
+            self._client._load()
+
+    @classmethod
+    def _document_from_scored_point(
+            cls,
+            scored_point: Any,
+            content_payload_key: str,
+            metadata_payload_key: str,
+    ) -> Document:
+        return Document(
+            page_content=scored_point.payload.get(content_payload_key),
+            metadata=scored_point.payload.get(metadata_payload_key) or {},
+        )
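
A minimal configuration sketch for the Qdrant backend. to_qdrant_params treats an endpoint beginning with 'path:' as a local, file-backed store resolved against root_path, and anything else as a server URL; all values below are illustrative:

    from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector

    # Remote server
    remote = QdrantVector(
        collection_name='Vector_index_example_Node',   # illustrative
        group_id='dataset-id',                         # illustrative dataset id
        config=QdrantConfig(endpoint='http://localhost:6333', api_key=None, root_path='/app/api')
    )

    # Local on-disk store, resolved relative to root_path
    local_config = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
    print(local_config.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}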

+ 62 - 0
api/core/rag/datasource/vdb/vector_base.py

@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from core.rag.models.document import Document
+
+
+class BaseVector(ABC):
+
+    def __init__(self, collection_name: str):
+        self._collection_name = collection_name
+
+    @abstractmethod
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def text_exists(self, id: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_ids(self, ids: list[str]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_by_vector(
+            self,
+            query_vector: list[float],
+            **kwargs: Any
+    ) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_by_full_text(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        raise NotImplementedError
+
+    def delete(self) -> None:
+        raise NotImplementedError
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def _get_uuids(self, texts: list[Document]) -> list[str]:
+        return [text.metadata['doc_id'] for text in texts]

+ 171 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -0,0 +1,171 @@
+from typing import Any, cast
+
+from flask import current_app
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
+
+
+class Vector:
+    def __init__(self, dataset: Dataset, attributes: list = None):
+        if attributes is None:
+            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self._dataset = dataset
+        self._embeddings = self._get_embeddings()
+        self._attributes = attributes
+        self._vector_processor = self._init_vector()
+
+    def _init_vector(self) -> BaseVector:
+        config = cast(dict, current_app.config)
+        vector_type = config.get('VECTOR_STORE')
+
+        if self._dataset.index_struct_dict:
+            vector_type = self._dataset.index_struct_dict['type']
+
+        if not vector_type:
+            raise ValueError("Vector store must be specified.")
+
+        if vector_type == "weaviate":
+            from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            return WeaviateVector(
+                collection_name=collection_name,
+                config=WeaviateConfig(
+                    endpoint=config.get('WEAVIATE_ENDPOINT'),
+                    api_key=config.get('WEAVIATE_API_KEY'),
+                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
+                ),
+                attributes=self._attributes
+            )
+        elif vector_type == "qdrant":
+            from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
+            if self._dataset.collection_binding_id:
+                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                    filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
+                    one_or_none()
+                if dataset_collection_binding:
+                    collection_name = dataset_collection_binding.collection_name
+                else:
+                    raise ValueError('Dataset Collection Binding does not exist!')
+            else:
+                if self._dataset.index_struct_dict:
+                    class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                    collection_name = class_prefix
+                else:
+                    dataset_id = self._dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+            return QdrantVector(
+                collection_name=collection_name,
+                group_id=self._dataset.id,
+                config=QdrantConfig(
+                    endpoint=config.get('QDRANT_URL'),
+                    api_key=config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path,
+                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
+                )
+            )
+        elif vector_type == "milvus":
+            from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            return MilvusVector(
+                collection_name=collection_name,
+                config=MilvusConfig(
+                    host=config.get('MILVUS_HOST'),
+                    port=config.get('MILVUS_PORT'),
+                    user=config.get('MILVUS_USER'),
+                    password=config.get('MILVUS_PASSWORD'),
+                    secure=config.get('MILVUS_SECURE'),
+                )
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def create(self, texts: list = None, **kwargs):
+        if texts:
+            embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
+            self._vector_processor.create(
+                texts=texts,
+                embeddings=embeddings,
+                **kwargs
+            )
+
+    def add_texts(self, documents: list[Document], **kwargs):
+        if kwargs.get('duplicate_check', False):
+            documents = self._filter_duplicate_texts(documents)
+        embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
+        self._vector_processor.add_texts(
+            documents=documents,
+            embeddings=embeddings,
+            **kwargs
+        )
+
+    def text_exists(self, id: str) -> bool:
+        return self._vector_processor.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._vector_processor.delete_by_ids(ids)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        self._vector_processor.delete_by_metadata_field(key, value)
+
+    def search_by_vector(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        query_vector = self._embeddings.embed_query(query)
+        return self._vector_processor.search_by_vector(query_vector, **kwargs)
+
+    def search_by_full_text(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        return self._vector_processor.search_by_full_text(query, **kwargs)
+
+    def delete(self) -> None:
+        self._vector_processor.delete()
+
+    def _get_embeddings(self) -> Embeddings:
+        model_manager = ModelManager()
+
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=self._dataset.tenant_id,
+            provider=self._dataset.embedding_model_provider,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=self._dataset.embedding_model
+
+        )
+        return CacheEmbedding(embedding_model)
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def __getattr__(self, name):
+        if self._vector_processor is not None:
+            method = getattr(self._vector_processor, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'vector_processor' object has no attribute '{name}'")

+ 0 - 0
api/core/rag/datasource/vdb/weaviate/__init__.py


+ 235 - 0
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -0,0 +1,235 @@
+import datetime
+from typing import Any, Optional
+
+import requests
+import weaviate
+from pydantic import BaseModel, root_validator
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from models.dataset import Dataset
+
+
+class WeaviateConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config WEAVIATE_ENDPOINT is required")
+        return values
+
+
+class WeaviateVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
+        super().__init__(collection_name)
+        self._client = self._init_client(config)
+        self._attributes = attributes
+
+    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
+        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+
+        weaviate.connect.connection.has_grpc = False
+
+        try:
+            client = weaviate.Client(
+                url=config.endpoint,
+                auth_client_secret=auth_config,
+                timeout_config=(5, 60),
+                startup_period=None
+            )
+        except requests.exceptions.ConnectionError:
+            raise ConnectionError("Vector database connection error")
+
+        client.batch.configure(
+            # `batch_size` takes an `int` value to enable auto-batching
+            # (`None` is used for manual batching)
+            batch_size=config.batch_size,
+            # dynamically update the `batch_size` based on import speed
+            dynamic=True,
+            # `timeout_retries` takes an `int` value to retry on time outs
+            timeout_retries=3,
+        )
+
+        return client
+
+    def get_type(self) -> str:
+        return 'weaviate'
+
+    def get_collection_name(self, dataset: Dataset) -> str:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self._collection_name}
+        }
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+
+        schema = self._default_schema(self._collection_name)
+
+        # check whether the index already exists
+        if not self._client.schema.contains(schema):
+            # create collection
+            self._client.schema.create_class(schema)
+        # create vector
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        ids = []
+
+        with self._client.batch as batch:
+            for i, text in enumerate(texts):
+                data_properties = {Field.TEXT_KEY.value: text}
+                if metadatas is not None:
+                    for key, val in metadatas[i].items():
+                        data_properties[key] = self._json_serializable(val)
+
+                batch.add_data_object(
+                    data_object=data_properties,
+                    class_name=self._collection_name,
+                    uuid=uuids[i],
+                    vector=embeddings[i] if embeddings else None,
+                )
+                ids.append(uuids[i])
+        return ids
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        where_filter = {
+            "operator": "Equal",
+            "path": [key],
+            "valueText": value
+        }
+
+        self._client.batch.delete_objects(
+            class_name=self._collection_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def delete(self):
+        self._client.schema.delete_class(self._collection_name)
+
+    def text_exists(self, id: str) -> bool:
+        collection_name = self._collection_name
+        result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": id,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][collection_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._client.data_object.delete(
+            ids,
+            class_name=self._collection_name
+        )
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """Look up similar documents by embedding vector in Weaviate."""
+        collection_name = self._collection_name
+        # copy so repeated searches do not keep appending to the shared attribute list
+        properties = self._attributes + [Field.TEXT_KEY.value]
+        query_obj = self._client.query.get(collection_name, properties)
+
+        vector = {"vector": query_vector}
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        result = (
+            query_obj.with_near_vector(vector)
+            .with_limit(kwargs.get("top_k", 4))
+            .with_additional(["vector", "distance"])
+            .do()
+        )
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        docs_and_scores = []
+        for res in result["data"]["Get"][collection_name]:
+            text = res.pop(Field.TEXT_KEY.value)
+            score = 1 - res["_additional"]["distance"]
+            docs_and_scores.append((Document(page_content=text, metadata=res), score))
+
+        docs = []
+        score_threshold = kwargs.get("score_threshold") or 0.0
+        for doc, score in docs_and_scores:
+            # check score threshold
+            if score > score_threshold:
+                doc.metadata['score'] = score
+                docs.append(doc)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """Return docs using BM25F.
+
+        Args:
+            query: Text to look up documents similar to.
+            top_k: Number of Documents to return (read from kwargs). Defaults to 2.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        collection_name = self._collection_name
+        content: dict[str, Any] = {"concepts": [query]}
+        # copy so repeated searches do not keep appending to the shared attribute list
+        properties = self._attributes + [Field.TEXT_KEY.value]
+        if kwargs.get("search_distance"):
+            content["certainty"] = kwargs.get("search_distance")
+        query_obj = self._client.query.get(collection_name, properties)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        if kwargs.get("additional"):
+            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        properties = ['text']
+        result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+        docs = []
+        for res in result["data"]["Get"][collection_name]:
+            text = res.pop(Field.TEXT_KEY.value)
+            docs.append(Document(page_content=text, metadata=res))
+        return docs
+
+    def _default_schema(self, index_name: str) -> dict:
+        return {
+            "class": index_name,
+            "properties": [
+                {
+                    "name": "text",
+                    "dataType": ["text"],
+                }
+            ],
+        }
+
+    def _json_serializable(self, value: Any) -> Any:
+        if isinstance(value, datetime.datetime):
+            return value.isoformat()
+        return value
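
For orientation, a minimal usage sketch of the new WeaviateVector wrapper (not part of the commit). The endpoint, API key, collection name and embedding values below are placeholders; in the application they come from configuration and from an embedding model.

    # Hypothetical example; all values are illustrative placeholders.
    from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
    from core.rag.models.document import Document

    config = WeaviateConfig(endpoint="http://localhost:8080", api_key="WEAVIATE_API_KEY", batch_size=100)
    vector = WeaviateVector(collection_name="Vector_index_example_Node", config=config, attributes=["doc_id"])

    docs = [Document(page_content="hello world", metadata={"doc_id": "doc-1"})]
    embeddings = [[0.1, 0.2, 0.3]]                 # normally produced by an embedding model

    vector.create(docs, embeddings)                # creates the class if missing, then indexes
    hits = vector.search_by_vector([0.1, 0.2, 0.3], top_k=4, score_threshold=0.5)
    bm25_hits = vector.search_by_full_text("hello", top_k=2)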

+ 166 - 0
api/core/rag/extractor/blod/blod.py

@@ -0,0 +1,166 @@
+"""Schema for Blobs and Blob Loaders.
+
+The goal is to facilitate decoupling of content loading from content parsing code.
+
+In addition, content loading code should provide a lazy loading interface by default.
+"""
+from __future__ import annotations
+
+import contextlib
+import mimetypes
+from abc import ABC, abstractmethod
+from collections.abc import Generator, Iterable, Mapping
+from io import BufferedReader, BytesIO
+from pathlib import PurePath
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, root_validator
+
+PathLike = Union[str, PurePath]
+
+
+class Blob(BaseModel):
+    """A blob is used to represent raw data by either reference or value.
+
+    Provides an interface to materialize the blob in different representations, and
+    help to decouple the development of data loaders from the downstream parsing of
+    the raw data.
+
+    Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
+    """
+
+    data: Union[bytes, str, None]  # Raw data
+    mimetype: Optional[str] = None  # Not to be confused with a file extension
+    encoding: str = "utf-8"  # Use utf-8 as default encoding, if decoding to string
+    # Location where the original content was found
+    # Represent location on the local file system
+    # Useful for situations where downstream code assumes it must work with file paths
+    # rather than in-memory content.
+    path: Optional[PathLike] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+        frozen = True
+
+    @property
+    def source(self) -> Optional[str]:
+        """The source location of the blob as string if known otherwise none."""
+        return str(self.path) if self.path else None
+
+    @root_validator(pre=True)
+    def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
+        """Verify that either data or path is provided."""
+        if "data" not in values and "path" not in values:
+            raise ValueError("Either data or path must be provided")
+        return values
+
+    def as_string(self) -> str:
+        """Read data as a string."""
+        if self.data is None and self.path:
+            with open(str(self.path), encoding=self.encoding) as f:
+                return f.read()
+        elif isinstance(self.data, bytes):
+            return self.data.decode(self.encoding)
+        elif isinstance(self.data, str):
+            return self.data
+        else:
+            raise ValueError(f"Unable to get string for blob {self}")
+
+    def as_bytes(self) -> bytes:
+        """Read data as bytes."""
+        if isinstance(self.data, bytes):
+            return self.data
+        elif isinstance(self.data, str):
+            return self.data.encode(self.encoding)
+        elif self.data is None and self.path:
+            with open(str(self.path), "rb") as f:
+                return f.read()
+        else:
+            raise ValueError(f"Unable to get bytes for blob {self}")
+
+    @contextlib.contextmanager
+    def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
+        """Read data as a byte stream."""
+        if isinstance(self.data, bytes):
+            yield BytesIO(self.data)
+        elif self.data is None and self.path:
+            with open(str(self.path), "rb") as f:
+                yield f
+        else:
+            raise NotImplementedError(f"Unable to convert blob {self}")
+
+    @classmethod
+    def from_path(
+        cls,
+        path: PathLike,
+        *,
+        encoding: str = "utf-8",
+        mime_type: Optional[str] = None,
+        guess_type: bool = True,
+    ) -> Blob:
+        """Load the blob from a path like object.
+
+        Args:
+            path: path like object to file to be read
+            encoding: Encoding to use if decoding the bytes into a string
+            mime_type: if provided, will be set as the mime-type of the data
+            guess_type: If True, the mimetype will be guessed from the file extension,
+                        if a mime-type was not provided
+
+        Returns:
+            Blob instance
+        """
+        if mime_type is None and guess_type:
+            _mimetype = mimetypes.guess_type(path)[0] if guess_type else None
+        else:
+            _mimetype = mime_type
+        # We do not load the data immediately, instead we treat the blob as a
+        # reference to the underlying data.
+        return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
+
+    @classmethod
+    def from_data(
+        cls,
+        data: Union[str, bytes],
+        *,
+        encoding: str = "utf-8",
+        mime_type: Optional[str] = None,
+        path: Optional[str] = None,
+    ) -> Blob:
+        """Initialize the blob from in-memory data.
+
+        Args:
+            data: the in-memory data associated with the blob
+            encoding: Encoding to use if decoding the bytes into a string
+            mime_type: if provided, will be set as the mime-type of the data
+            path: if provided, will be set as the source from which the data came
+
+        Returns:
+            Blob instance
+        """
+        return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
+
+    def __repr__(self) -> str:
+        """Define the blob representation."""
+        str_repr = f"Blob {id(self)}"
+        if self.source:
+            str_repr += f" {self.source}"
+        return str_repr
+
+
+class BlobLoader(ABC):
+    """Abstract interface for blob loaders implementation.
+
+    Implementer should be able to load raw content from a datasource system according
+    to some criteria and return the raw content lazily as a stream of blobs.
+    """
+
+    @abstractmethod
+    def yield_blobs(
+        self,
+    ) -> Iterable[Blob]:
+        """A lazy loader for raw data represented by LangChain's Blob object.
+
+        Returns:
+            A generator over blobs
+        """

+ 71 - 0
api/core/rag/extractor/csv_extractor.py

@@ -0,0 +1,71 @@
+"""Abstract interface for document loader implementations."""
+import csv
+from typing import Optional
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
+
+
+class CSVExtractor(BaseExtractor):
+    """Load CSV files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
+            source_column: Optional[str] = None,
+            csv_args: Optional[dict] = None,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+        self.source_column = source_column
+        self.csv_args = csv_args or {}
+
+    def extract(self) -> list[Document]:
+        """Load data into document objects."""
+        try:
+            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
+                docs = self._read_from_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(self._file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
+                            docs = self._read_from_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self._file_path}") from e
+
+        return docs
+
+    def _read_from_file(self, csvfile) -> list[Document]:
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
+            try:
+                source = (
+                    row[self.source_column]
+                    if self.source_column is not None
+                    else ''
+                )
+            except KeyError:
+                raise ValueError(
+                    f"Source column '{self.source_column}' not found in CSV file."
+                )
+            metadata = {"source": source, "row": i}
+            doc = Document(page_content=content, metadata=metadata)
+            docs.append(doc)
+
+        return docs
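
Sketch of extracting a CSV into Document objects (hypothetical file path).

    # Hypothetical example; "items.csv" is a placeholder.
    from core.rag.extractor.csv_extractor import CSVExtractor

    extractor = CSVExtractor("items.csv", autodetect_encoding=True)
    for doc in extractor.extract():
        print(doc.metadata["row"])
        print(doc.page_content)    # "column: value" pairs joined by newlines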

+ 6 - 0
api/core/rag/extractor/entity/datasource_type.py

@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class DatasourceType(Enum):
+    FILE = "upload_file"
+    NOTION = "notion_import"

+ 36 - 0
api/core/rag/extractor/entity/extract_setting.py

@@ -0,0 +1,36 @@
+from pydantic import BaseModel
+
+from models.dataset import Document
+from models.model import UploadFile
+
+
+class NotionInfo(BaseModel):
+    """
+    Notion import info.
+    """
+    notion_workspace_id: str
+    notion_obj_id: str
+    notion_page_type: str
+    document: Document = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)
+
+
+class ExtractSetting(BaseModel):
+    """
+    Model class for provider response.
+    """
+    datasource_type: str
+    upload_file: UploadFile = None
+    notion_info: NotionInfo = None
+    document_model: str = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)
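
As a rough illustration, an ExtractSetting for a Notion page might be built like this. The IDs are placeholders, and the optional Document/UploadFile records normally come from the database.

    # Hypothetical values; real instances come from the database.
    from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo

    notion_setting = ExtractSetting(
        datasource_type="notion_import",
        notion_info=NotionInfo(
            notion_workspace_id="workspace-id",
            notion_obj_id="page-id",
            notion_page_type="page",
        ),
        document_model="text_model",
    )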

+ 14 - 9
api/core/data_loader/loader/excel.py → api/core/rag/extractor/excel_extractor.py

@@ -1,14 +1,14 @@
-import logging
+"""Abstract interface for document loader implementations."""
+from typing import Optional
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
 from openpyxl.reader.excel import load_workbook
 from openpyxl.reader.excel import load_workbook
 
 
-logger = logging.getLogger(__name__)
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 
 
-class ExcelLoader(BaseLoader):
-    """Load xlxs files.
+class ExcelExtractor(BaseExtractor):
+    """Load Excel files.
 
 
 
 
     Args:
     Args:
@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader):
     """
     """
 
 
     def __init__(
     def __init__(
-        self,
-        file_path: str
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False
     ):
     ):
         """Initialize with file path."""
         """Initialize with file path."""
         self._file_path = file_path
         self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
+        """Load from file path."""
         data = []
         data = []
         keys = []
         keys = []
         wb = load_workbook(filename=self._file_path, read_only=True)
         wb = load_workbook(filename=self._file_path, read_only=True)

+ 139 - 0
api/core/rag/extractor/extract_processor.py

@@ -0,0 +1,139 @@
+import tempfile
+from pathlib import Path
+from typing import Union
+
+import requests
+from flask import current_app
+
+from core.rag.extractor.csv_extractor import CSVExtractor
+from core.rag.extractor.entity.datasource_type import DatasourceType
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.html_extractor import HtmlExtractor
+from core.rag.extractor.markdown_extractor import MarkdownExtractor
+from core.rag.extractor.notion_extractor import NotionExtractor
+from core.rag.extractor.pdf_extractor import PdfExtractor
+from core.rag.extractor.text_extractor import TextExtractor
+from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
+from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
+from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
+from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
+from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
+from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
+from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
+from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
+from core.rag.extractor.word_extractor import WordExtractor
+from core.rag.models.document import Document
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+
+
+class ExtractProcessor:
+    @classmethod
+    def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
+            -> Union[list[Document], str]:
+        extract_setting = ExtractSetting(
+            datasource_type="upload_file",
+            upload_file=upload_file,
+            document_model='text_model'
+        )
+        if return_text:
+            delimiter = '\n'
+            return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
+        else:
+            return cls.extract(extract_setting, is_automatic)
+
+    @classmethod
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
+        response = requests.get(url, headers={
+            "User-Agent": USER_AGENT
+        })
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(url).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            with open(file_path, 'wb') as file:
+                file.write(response.content)
+            extract_setting = ExtractSetting(
+                datasource_type="upload_file",
+                document_model='text_model'
+            )
+            if return_text:
+                delimiter = '\n'
+                return delimiter.join([document.page_content for document in cls.extract(
+                    extract_setting=extract_setting, file_path=file_path)])
+            else:
+                return cls.extract(extract_setting=extract_setting, file_path=file_path)
+
+    @classmethod
+    def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
+                file_path: str = None) -> list[Document]:
+        if extract_setting.datasource_type == DatasourceType.FILE.value:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                if not file_path:
+                    upload_file: UploadFile = extract_setting.upload_file
+                    suffix = Path(upload_file.key).suffix
+                    file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+                    storage.download(upload_file.key, file_path)
+                input_file = Path(file_path)
+                file_extension = input_file.suffix.lower()
+                etl_type = current_app.config['ETL_TYPE']
+                unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
+                if etl_type == 'Unstructured':
+                    if file_extension == '.xlsx':
+                        extractor = ExcelExtractor(file_path)
+                    elif file_extension == '.pdf':
+                        extractor = PdfExtractor(file_path)
+                    elif file_extension in ['.md', '.markdown']:
+                        extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
+                            else MarkdownExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension in ['.htm', '.html']:
+                        extractor = HtmlExtractor(file_path)
+                    elif file_extension in ['.docx']:
+                        extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.csv':
+                        extractor = CSVExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension == '.msg':
+                        extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.eml':
+                        extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.ppt':
+                        extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.pptx':
+                        extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.xml':
+                        extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
+                    else:
+                        # txt
+                        extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
+                            else TextExtractor(file_path, autodetect_encoding=True)
+                else:
+                    if file_extension == '.xlsx':
+                        extractor = ExcelExtractor(file_path)
+                    elif file_extension == '.pdf':
+                        extractor = PdfExtractor(file_path)
+                    elif file_extension in ['.md', '.markdown']:
+                        extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension in ['.htm', '.html']:
+                        extractor = HtmlExtractor(file_path)
+                    elif file_extension in ['.docx']:
+                        extractor = WordExtractor(file_path)
+                    elif file_extension == '.csv':
+                        extractor = CSVExtractor(file_path, autodetect_encoding=True)
+                    else:
+                        # txt
+                        extractor = TextExtractor(file_path, autodetect_encoding=True)
+                return extractor.extract()
+        elif extract_setting.datasource_type == DatasourceType.NOTION.value:
+            extractor = NotionExtractor(
+                notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
+                notion_obj_id=extract_setting.notion_info.notion_obj_id,
+                notion_page_type=extract_setting.notion_info.notion_page_type,
+                document_model=extract_setting.notion_info.document
+            )
+            return extractor.extract()
+        else:
+            raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

+ 12 - 0
api/core/rag/extractor/extractor_base.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document loader implementations."""
+from abc import ABC, abstractmethod
+
+
+class BaseExtractor(ABC):
+    """Interface for extract files.
+    """
+
+    @abstractmethod
+    def extract(self):
+        raise NotImplementedError
+

+ 46 - 0
api/core/rag/extractor/helpers.py

@@ -0,0 +1,46 @@
+"""Document loader helpers."""
+
+import concurrent.futures
+from typing import NamedTuple, Optional, cast
+
+
+class FileEncoding(NamedTuple):
+    """A file encoding as the NamedTuple."""
+
+    encoding: Optional[str]
+    """The encoding of the file."""
+    confidence: float
+    """The confidence of the encoding."""
+    language: Optional[str]
+    """The language of the file."""
+
+
+def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
+    """Try to detect the file encoding.
+
+    Returns a list of `FileEncoding` tuples with the detected encodings ordered
+    by confidence.
+
+    Args:
+        file_path: The path to the file to detect the encoding for.
+        timeout: The timeout in seconds for the encoding detection.
+    """
+    import chardet
+
+    def read_and_detect(file_path: str) -> list[dict]:
+        with open(file_path, "rb") as f:
+            rawdata = f.read()
+        return cast(list[dict], chardet.detect_all(rawdata))
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future = executor.submit(read_and_detect, file_path)
+        try:
+            encodings = future.result(timeout=timeout)
+        except concurrent.futures.TimeoutError:
+            raise TimeoutError(
+                f"Timeout reached while detecting encoding for {file_path}"
+            )
+
+    if all(encoding["encoding"] is None for encoding in encodings):
+        raise RuntimeError(f"Could not detect encoding for {file_path}")
+    return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
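
For example (hypothetical path), the helper can be used to pick the most confident encoding before re-reading a file.

    # Hypothetical example; "legacy.txt" is a placeholder path.
    from core.rag.extractor.helpers import detect_file_encodings

    encodings = detect_file_encodings("legacy.txt", timeout=5)
    best = encodings[0]                        # ordered by chardet confidence
    with open("legacy.txt", encoding=best.encoding) as f:
        content = f.read()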

+ 24 - 20
api/core/data_loader/loader/csv_loader.py → api/core/rag/extractor/html_extractor.py

@@ -1,51 +1,55 @@
-import logging
+"""Abstract interface for document loader implementations."""
+import csv
 from typing import Optional
 from typing import Optional
 
 
-from langchain.document_loaders import CSVLoader as LCCSVLoader
-from langchain.document_loaders.helpers import detect_file_encodings
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
 
 
-logger = logging.getLogger(__name__)
 
 
+class HtmlExtractor(BaseExtractor):
+    """Load html files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
 
 
-class CSVLoader(LCCSVLoader):
     def __init__(
     def __init__(
             self,
             self,
             file_path: str,
             file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
             source_column: Optional[str] = None,
             source_column: Optional[str] = None,
             csv_args: Optional[dict] = None,
             csv_args: Optional[dict] = None,
-            encoding: Optional[str] = None,
-            autodetect_encoding: bool = True,
     ):
     ):
-        self.file_path = file_path
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
         self.source_column = source_column
         self.source_column = source_column
-        self.encoding = encoding
         self.csv_args = csv_args or {}
         self.csv_args = csv_args or {}
-        self.autodetect_encoding = autodetect_encoding
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         """Load data into document objects."""
         """Load data into document objects."""
         try:
         try:
-            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                 docs = self._read_from_file(csvfile)
                 docs = self._read_from_file(csvfile)
         except UnicodeDecodeError as e:
         except UnicodeDecodeError as e:
-            if self.autodetect_encoding:
-                detected_encodings = detect_file_encodings(self.file_path)
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
                 for encoding in detected_encodings:
-                    logger.debug("Trying encoding: ", encoding.encoding)
                     try:
                     try:
-                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
+                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
                             docs = self._read_from_file(csvfile)
                             docs = self._read_from_file(csvfile)
                         break
                         break
                     except UnicodeDecodeError:
                     except UnicodeDecodeError:
                         continue
                         continue
             else:
             else:
-                raise RuntimeError(f"Error loading {self.file_path}") from e
+                raise RuntimeError(f"Error loading {self._file_path}") from e
 
 
         return docs
         return docs
 
 
-    def _read_from_file(self, csvfile):
+    def _read_from_file(self, csvfile) -> list[Document]:
         docs = []
         docs = []
         csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
         csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
         for i, row in enumerate(csv_reader):
         for i, row in enumerate(csv_reader):

+ 14 - 26
api/core/data_loader/loader/markdown.py → api/core/rag/extractor/markdown_extractor.py

@@ -1,39 +1,27 @@
-import logging
+"""Abstract interface for document loader implementations."""
 import re
 import re
 from typing import Optional, cast
 from typing import Optional, cast
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
 
 
-logger = logging.getLogger(__name__)
 
 
-
-class MarkdownLoader(BaseLoader):
-    """Load md files.
+class MarkdownExtractor(BaseExtractor):
+    """Load Markdown files.
 
 
 
 
     Args:
     Args:
         file_path: Path to the file to load.
         file_path: Path to the file to load.
-
-        remove_hyperlinks: Whether to remove hyperlinks from the text.
-
-        remove_images: Whether to remove images from the text.
-
-        encoding: File encoding to use. If `None`, the file will be loaded
-        with the default system encoding.
-
-        autodetect_encoding: Whether to try to autodetect the file encoding
-            if the specified encoding fails.
     """
     """
 
 
     def __init__(
     def __init__(
-        self,
-        file_path: str,
-        remove_hyperlinks: bool = True,
-        remove_images: bool = True,
-        encoding: Optional[str] = None,
-        autodetect_encoding: bool = True,
+            self,
+            file_path: str,
+            remove_hyperlinks: bool = True,
+            remove_images: bool = True,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
     ):
     ):
         """Initialize with file path."""
         """Initialize with file path."""
         self._file_path = file_path
         self._file_path = file_path
@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader):
         self._encoding = encoding
         self._encoding = encoding
         self._autodetect_encoding = autodetect_encoding
         self._autodetect_encoding = autodetect_encoding
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
+        """Load from file path."""
         tups = self.parse_tups(self._file_path)
         tups = self.parse_tups(self._file_path)
         documents = []
         documents = []
         for header, value in tups:
         for header, value in tups:
@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader):
             if self._autodetect_encoding:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(filepath)
                 detected_encodings = detect_file_encodings(filepath)
                 for encoding in detected_encodings:
                 for encoding in detected_encodings:
-                    logger.debug("Trying encoding: ", encoding.encoding)
                     try:
                     try:
                         with open(filepath, encoding=encoding.encoding) as f:
                         with open(filepath, encoding=encoding.encoding) as f:
                             content = f.read()
                             content = f.read()

+ 23 - 37
api/core/data_loader/loader/notion.py → api/core/rag/extractor/notion_extractor.py

@@ -4,9 +4,10 @@ from typing import Any, Optional
 
 
 import requests
 import requests
 from flask import current_app
 from flask import current_app
-from langchain.document_loaders.base import BaseLoader
+from flask_login import current_user
 from langchain.schema import Document
 from langchain.schema import Document
 
 
+from core.rag.extractor.extractor_base import BaseExtractor
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
 from models.dataset import Document as DocumentModel
 from models.source import DataSourceBinding
 from models.source import DataSourceBinding
@@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
 
 
 
 
-class NotionLoader(BaseLoader):
+class NotionExtractor(BaseExtractor):
+
     def __init__(
     def __init__(
             self,
             self,
-            notion_access_token: str,
             notion_workspace_id: str,
             notion_workspace_id: str,
             notion_obj_id: str,
             notion_obj_id: str,
             notion_page_type: str,
             notion_page_type: str,
-            document_model: Optional[DocumentModel] = None
+            document_model: Optional[DocumentModel] = None,
+            notion_access_token: Optional[str] = None
     ):
     ):
+        self._notion_access_token = None
         self._document_model = document_model
         self._document_model = document_model
         self._notion_workspace_id = notion_workspace_id
         self._notion_workspace_id = notion_workspace_id
         self._notion_obj_id = notion_obj_id
         self._notion_obj_id = notion_obj_id
         self._notion_page_type = notion_page_type
         self._notion_page_type = notion_page_type
-        self._notion_access_token = notion_access_token
-
-        if not self._notion_access_token:
-            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
-            if integration_token is None:
-                raise ValueError(
-                    "Must specify `integration_token` or set environment "
-                    "variable `NOTION_INTEGRATION_TOKEN`."
-                )
-
-            self._notion_access_token = integration_token
-
-    @classmethod
-    def from_document(cls, document_model: DocumentModel):
-        data_source_info = document_model.data_source_info_dict
-        if not data_source_info or 'notion_page_id' not in data_source_info \
-                or 'notion_workspace_id' not in data_source_info:
-            raise ValueError("no notion page found")
-
-        notion_workspace_id = data_source_info['notion_workspace_id']
-        notion_obj_id = data_source_info['notion_page_id']
-        notion_page_type = data_source_info['type']
-        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
-
-        return cls(
-            notion_access_token=notion_access_token,
-            notion_workspace_id=notion_workspace_id,
-            notion_obj_id=notion_obj_id,
-            notion_page_type=notion_page_type,
-            document_model=document_model
-        )
-
-    def load(self) -> list[Document]:
+        if notion_access_token:
+            self._notion_access_token = notion_access_token
+        else:
+            self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
+                                                               self._notion_workspace_id)
+            if not self._notion_access_token:
+                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+                if integration_token is None:
+                    raise ValueError(
+                        "Must specify `integration_token` or set environment "
+                        "variable `NOTION_INTEGRATION_TOKEN`."
+                    )
+
+                self._notion_access_token = integration_token
+
+    def extract(self) -> list[Document]:
         self.update_last_edited_time(
         self.update_last_edited_time(
             self._document_model
             self._document_model
         )
         )

+ 72 - 0
api/core/rag/extractor/pdf_extractor.py

@@ -0,0 +1,72 @@
+"""Abstract interface for document loader implementations."""
+from collections.abc import Iterator
+from typing import Optional
+
+from core.rag.extractor.blod.blod import Blob
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+from extensions.ext_storage import storage
+
+
+class PdfExtractor(BaseExtractor):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            file_cache_key: Optional[str] = None
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._file_cache_key = file_cache_key
+
+    def extract(self) -> list[Document]:
+        plaintext_file_exists = False
+        if self._file_cache_key:
+            try:
+                text = storage.load(self._file_cache_key).decode('utf-8')
+                plaintext_file_exists = True
+                return [Document(page_content=text)]
+            except FileNotFoundError:
+                pass
+        documents = list(self.load())
+        text_list = []
+        for document in documents:
+            text_list.append(document.page_content)
+        text = "\n\n".join(text_list)
+
+        # save plaintext file for caching (keyed by the file cache key, if provided)
+        if not plaintext_file_exists and self._file_cache_key:
+            storage.save(self._file_cache_key, text.encode('utf-8'))
+
+        return documents
+
+    def load(
+            self,
+    ) -> Iterator[Document]:
+        """Lazy load given path as pages."""
+        blob = Blob.from_path(self._file_path)
+        yield from self.parse(blob)
+
+    def parse(self, blob: Blob) -> Iterator[Document]:
+        """Lazily parse the blob."""
+        import pypdfium2
+
+        with blob.as_bytes_io() as file_path:
+            pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
+            try:
+                for page_number, page in enumerate(pdf_reader):
+                    text_page = page.get_textpage()
+                    content = text_page.get_text_range()
+                    text_page.close()
+                    page.close()
+                    metadata = {"source": blob.source, "page": page_number}
+                    yield Document(page_content=content, metadata=metadata)
+            finally:
+                pdf_reader.close()
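
Minimal sketch (placeholder path; the optional cache key is a storage key, not a local path).

    # Hypothetical example; path and cache key are placeholders.
    from core.rag.extractor.pdf_extractor import PdfExtractor

    extractor = PdfExtractor("manual.pdf", file_cache_key="upload_files/manual.txt")
    pages = extractor.extract()                # one Document per page
    print(pages[0].metadata)                   # {"source": "manual.pdf", "page": 0}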

+ 50 - 0
api/core/rag/extractor/text_extractor.py

@@ -0,0 +1,50 @@
+"""Abstract interface for document loader implementations."""
+from typing import Optional
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
+
+
+class TextExtractor(BaseExtractor):
+    """Load text files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def extract(self) -> list[Document]:
+        """Load from file path."""
+        text = ""
+        try:
+            with open(self._file_path, encoding=self._encoding) as f:
+                text = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(self._file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(self._file_path, encoding=encoding.encoding) as f:
+                            text = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self._file_path}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {self._file_path}") from e
+
+        metadata = {"source": self._file_path}
+        return [Document(page_content=text, metadata=metadata)]
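
Usage mirrors the other extractors (placeholder path).

    # Hypothetical example; "notes.txt" is a placeholder.
    from core.rag.extractor.text_extractor import TextExtractor

    docs = TextExtractor("notes.txt", autodetect_encoding=True).extract()
    assert len(docs) == 1                      # the whole file becomes a single Document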

+ 61 - 0
api/core/rag/extractor/unstructured/unstructured_doc_extractor.py

@@ -0,0 +1,61 @@
+import logging
+import os
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class UnstructuredWordExtractor(BaseExtractor):
+    """Loader that uses unstructured to load word documents.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            api_url: str,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._api_url = api_url
+
+    def extract(self) -> list[Document]:
+        from unstructured.__version__ import __version__ as __unstructured_version__
+        from unstructured.file_utils.filetype import FileType, detect_filetype
+
+        unstructured_version = tuple(
+            [int(x) for x in __unstructured_version__.split(".")]
+        )
+        # check the file extension
+        try:
+            import magic  # noqa: F401
+
+            is_doc = detect_filetype(self._file_path) == FileType.DOC
+        except ImportError:
+            _, extension = os.path.splitext(str(self._file_path))
+            is_doc = extension == ".doc"
+
+        if is_doc and unstructured_version < (0, 4, 11):
+            raise ValueError(
+                f"You are on unstructured version {__unstructured_version__}. "
+                "Partitioning .doc files is only supported in unstructured>=0.4.11. "
+                "Please upgrade the unstructured package and try again."
+            )
+
+        if is_doc:
+            from unstructured.partition.doc import partition_doc
+
+            elements = partition_doc(filename=self._file_path)
+        else:
+            from unstructured.partition.docx import partition_docx
+
+            elements = partition_docx(filename=self._file_path)
+
+        from unstructured.chunking.title import chunk_by_title
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        documents = []
+        for chunk in chunks:
+            text = chunk.text.strip()
+            documents.append(Document(page_content=text))
+        return documents
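
Sketch (placeholder path; the API URL normally comes from UNSTRUCTURED_API_URL and may be None for local partitioning).

    # Hypothetical example; path and API URL are placeholders.
    from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor

    extractor = UnstructuredWordExtractor("contract.docx", api_url=None)
    chunks = extractor.extract()               # title-based chunks of at most 2000 characters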

+ 5 - 4
api/core/data_loader/loader/unstructured/unstructured_eml.py → api/core/rag/extractor/unstructured/unstructured_eml_extractor.py

@@ -2,13 +2,14 @@ import base64
 import logging
 import logging
 
 
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-class UnstructuredEmailLoader(BaseLoader):
+class UnstructuredEmailExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
     Args:
     Args:
         file_path: Path to the file to load.
         file_path: Path to the file to load.
@@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader):
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.email import partition_email
         from unstructured.partition.email import partition_email
         elements = partition_email(filename=self._file_path, api_url=self._api_url)
         elements = partition_email(filename=self._file_path, api_url=self._api_url)
 
 

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_markdown.py → api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py

@@ -1,12 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-class UnstructuredMarkdownLoader(BaseLoader):
+class UnstructuredMarkdownExtractor(BaseExtractor):
     """Load md files.
     """Load md files.
 
 
 
 
@@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.md import partition_md
         from unstructured.partition.md import partition_md
 
 
         elements = partition_md(filename=self._file_path, api_url=self._api_url)
         elements = partition_md(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_msg.py → api/core/rag/extractor/unstructured/unstructured_msg_extractor.py

@@ -1,12 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-class UnstructuredMsgLoader(BaseLoader):
+class UnstructuredMsgExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
 
 
 
 
@@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.msg import partition_msg
         from unstructured.partition.msg import partition_msg
 
 
         elements = partition_msg(filename=self._file_path, api_url=self._api_url)
         elements = partition_msg(filename=self._file_path, api_url=self._api_url)

+ 8 - 7
api/core/data_loader/loader/unstructured/unstructured_ppt.py → api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py

@@ -1,11 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredPPTLoader(BaseLoader):
+
+class UnstructuredPPTExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
 
 
 
 
@@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader):
     """
     """
 
 
     def __init__(
     def __init__(
-        self,
-        file_path: str,
-        api_url: str
+            self,
+            file_path: str,
+            api_url: str
     ):
     ):
         """Initialize with file path."""
         """Initialize with file path."""
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.ppt import partition_ppt
         from unstructured.partition.ppt import partition_ppt
 
 
         elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
         elements = partition_ppt(filename=self._file_path, api_url=self._api_url)

+ 9 - 7
api/core/data_loader/loader/unstructured/unstructured_pptx.py → api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py

@@ -1,10 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
-class UnstructuredPPTXLoader(BaseLoader):
+
+
+class UnstructuredPPTXExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
 
 
 
 
@@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader):
     """
     """
 
 
     def __init__(
     def __init__(
-        self,
-        file_path: str,
-        api_url: str
+            self,
+            file_path: str,
+            api_url: str
     ):
     ):
         """Initialize with file path."""
         """Initialize with file path."""
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.pptx import partition_pptx
         from unstructured.partition.pptx import partition_pptx
 
 
         elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
         elements = partition_pptx(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_text.py → api/core/rag/extractor/unstructured/unstructured_text_extractor.py

@@ -1,12 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-class UnstructuredTextLoader(BaseLoader):
+class UnstructuredTextExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
 
 
 
 
@@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.text import partition_text
         from unstructured.partition.text import partition_text
 
 
         elements = partition_text(filename=self._file_path, api_url=self._api_url)
         elements = partition_text(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_xml.py → api/core/rag/extractor/unstructured/unstructured_xml_extractor.py

@@ -1,12 +1,12 @@
 import logging
 import logging
 
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-class UnstructuredXmlLoader(BaseLoader):
+class UnstructuredXmlExtractor(BaseExtractor):
     """Load msg files.
     """Load msg files.
 
 
 
 
@@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
         self._file_path = file_path
         self._file_path = file_path
         self._api_url = api_url
         self._api_url = api_url
 
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.xml import partition_xml
         from unstructured.partition.xml import partition_xml
 
 
         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)

+ 62 - 0
api/core/rag/extractor/word_extractor.py

@@ -0,0 +1,62 @@
+"""Abstract interface for document loader implementations."""
+import os
+import tempfile
+from urllib.parse import urlparse
+
+import requests
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+
+
+class WordExtractor(BaseExtractor):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(self, file_path: str):
+        """Initialize with file path."""
+        self.file_path = file_path
+        if "~" in self.file_path:
+            self.file_path = os.path.expanduser(self.file_path)
+
+        # If the file is a web path, download it to a temporary file, and use that
+        if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
+            r = requests.get(self.file_path)
+
+            if r.status_code != 200:
+                raise ValueError(
+                    "Check the url of your file; returned status code %s"
+                    % r.status_code
+                )
+
+            self.web_path = self.file_path
+            self.temp_file = tempfile.NamedTemporaryFile()
+            self.temp_file.write(r.content)
+            self.file_path = self.temp_file.name
+        elif not os.path.isfile(self.file_path):
+            raise ValueError("File path %s is not a valid file or url" % self.file_path)
+
+    def __del__(self) -> None:
+        if hasattr(self, "temp_file"):
+            self.temp_file.close()
+
+    def extract(self) -> list[Document]:
+        """Load given path as single page."""
+        import docx2txt
+
+        return [
+            Document(
+                page_content=docx2txt.process(self.file_path),
+                metadata={"source": self.file_path},
+            )
+        ]
+
+    @staticmethod
+    def _is_valid_url(url: str) -> bool:
+        """Check if the url is valid."""
+        parsed = urlparse(url)
+        return bool(parsed.netloc) and bool(parsed.scheme)
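
Sketch (placeholder path; a URL would also be downloaded to a temporary file first).

    # Hypothetical example; "report.docx" is a placeholder.
    from core.rag.extractor.word_extractor import WordExtractor

    docs = WordExtractor("report.docx").extract()   # docx2txt flattens the document to text
    print(docs[0].metadata["source"])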

+ 0 - 0
api/core/rag/index_processor/__init__.py


+ 0 - 0
api/core/rag/index_processor/constant/__init__.py


+ 8 - 0
api/core/rag/index_processor/constant/index_type.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class IndexType(Enum):
+    PARAGRAPH_INDEX = "text_model"
+    QA_INDEX = "qa_model"
+    PARENT_CHILD_INDEX = "parent_child_index"
+    SUMMARY_INDEX = "summary_index"

+ 70 - 0
api/core/rag/index_processor/index_processor_base.py

@@ -0,0 +1,70 @@
+"""Abstract interface for document loader implementations."""
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from langchain.text_splitter import TextSplitter
+
+from core.model_manager import ModelInstance
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.models.document import Document
+from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
+from models.dataset import Dataset, DatasetProcessRule
+
+
+class BaseIndexProcessor(ABC):
+    """Interface for extract files.
+    """
+
+    @abstractmethod
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        raise NotImplementedError
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        raise NotImplementedError
+
+    @abstractmethod
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        raise NotImplementedError
+
+    def _get_splitter(self, processing_rule: dict,
+                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
+        """
+        Get the TextSplitter object according to the processing rule.
+        """
+        if processing_rule['mode'] == "custom":
+            # The user-defined segmentation rule
+            rules = processing_rule['rules']
+            segmentation = rules["segmentation"]
+            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
+                raise ValueError("Custom segment length should be between 50 and 1000.")
+
+            separator = segmentation["separator"]
+            if separator:
+                separator = separator.replace('\\n', '\n')
+
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=segmentation["max_tokens"],
+                chunk_overlap=0,
+                fixed_separator=separator,
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
+            )
+        else:
+            # Automatic segmentation
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
+                chunk_overlap=0,
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
+            )
+
+        return character_splitter
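To make the expected shape of `processing_rule` concrete, here is a sketch of calling `_get_splitter` on a concrete processor with a custom rule (key names follow the code above; the 500-token limit is an example value, and `processor` is assumed to be an instance of any BaseIndexProcessor subclass):

    process_rule = {
        'mode': 'custom',
        'rules': {
            'segmentation': {
                'max_tokens': 500,     # must lie between 50 and 1000
                'separator': '\\n'     # the literal '\\n' is unescaped to a newline
            }
        }
    }
    splitter = processor._get_splitter(processing_rule=process_rule,
                                       embedding_model_instance=None)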

+ 28 - 0
api/core/rag/index_processor/index_processor_factory.py

@@ -0,0 +1,28 @@
+"""Abstract interface for document loader implementations."""
+
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
+
+
+class IndexProcessorFactory:
+    """IndexProcessorInit.
+    """
+
+    def __init__(self, index_type: str):
+        self._index_type = index_type
+
+    def init_index_processor(self) -> BaseIndexProcessor:
+        """Init index processor."""
+
+        if not self._index_type:
+            raise ValueError("Index type must be specified.")
+
+        if self._index_type == IndexType.PARAGRAPH_INDEX.value:
+            return ParagraphIndexProcessor()
+        elif self._index_type == IndexType.QA_INDEX.value:
+
+            return QAIndexProcessor()
+        else:
+            raise ValueError(f"Index type {self._index_type} is not supported.")
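Typical use of the factory (a minimal sketch):

    from core.rag.index_processor.constant.index_type import IndexType
    from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

    # "text_model" yields a ParagraphIndexProcessor, "qa_model" a QAIndexProcessor;
    # any other value raises ValueError.
    index_processor = IndexProcessorFactory(IndexType.PARAGRAPH_INDEX.value).init_index_processor()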

+ 0 - 0
api/core/rag/index_processor/processor/__init__.py


+ 92 - 0
api/core/rag/index_processor/processor/paragraph_index_processor.py

@@ -0,0 +1,92 @@
+"""Paragraph index processor."""
+import uuid
+from typing import Optional
+
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+
+
+class ParagraphIndexProcessor(BaseIndexProcessor):
+
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=kwargs.get('process_rule_mode') == "automatic")
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        # Split the text documents into nodes.
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
+        all_documents = []
+        for document in documents:
+            # document clean
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+            # parse document to nodes
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    # delete leading splitter character
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    else:
+                        page_content = page_content
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+        if with_keywords:
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            if node_ids:
+                vector.delete_by_ids(node_ids)
+            else:
+                vector.delete()
+        if with_keywords:
+            keyword = Keyword(dataset)
+            if node_ids:
+                keyword.delete_by_ids(node_ids)
+            else:
+                keyword.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
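End to end, the processor is driven as extract → transform → load; a sketch assuming `extract_setting`, `process_rule`, `embedding_model_instance` and `dataset` have been prepared elsewhere (e.g. by the indexing runner):

    from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

    processor = ParagraphIndexProcessor()
    text_docs = processor.extract(extract_setting, process_rule_mode="automatic")
    chunks = processor.transform(text_docs,
                                 process_rule=process_rule,
                                 embedding_model_instance=embedding_model_instance)
    processor.load(dataset, chunks, with_keywords=True)   # vector index plus keyword table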

+ 161 - 0
api/core/rag/index_processor/processor/qa_index_processor.py

@@ -0,0 +1,161 @@
+"""Paragraph index processor."""
+import logging
+import re
+import threading
+import uuid
+from typing import Optional
+
+import pandas as pd
+from flask import Flask, current_app
+from flask_login import current_user
+from werkzeug.datastructures import FileStorage
+
+from core.generator.llm_generator import LLMGenerator
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+
+
+class QAIndexProcessor(BaseIndexProcessor):
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=kwargs.get('process_rule_mode') == "automatic")
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=None)
+
+        # Split the text documents into nodes.
+        all_documents = []
+        all_qa_documents = []
+        for document in documents:
+            # document clean
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+
+            # parse document to nodes
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    # delete leading splitter character
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    else:
+                        page_content = page_content
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        for i in range(0, len(all_documents), 10):
+            threads = []
+            sub_documents = all_documents[i:i + 10]
+            for doc in sub_documents:
+                document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'tenant_id': current_user.current_tenant.id,
+                    'document_node': doc,
+                    'all_qa_documents': all_qa_documents,
+                    'document_language': kwargs.get('document_language', 'English')})
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+        return all_qa_documents
+
+    def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
+
+        # check file type
+        if not file.filename.endswith('.csv'):
+            raise ValueError("Invalid file type. Only CSV files are allowed")
+
+        try:
+            # Skip the first row
+            df = pd.read_csv(file)
+            text_docs = []
+            for index, row in df.iterrows():
+                data = Document(page_content=row[0], metadata={'answer': row[1]})
+                text_docs.append(data)
+            if len(text_docs) == 0:
+                raise ValueError("The CSV file is empty.")
+
+        except Exception as e:
+            raise ValueError(str(e))
+        return text_docs
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        vector = Vector(dataset)
+        if node_ids:
+            vector.delete_by_ids(node_ids)
+        else:
+            vector.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict):
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
+        format_documents = []
+        if document_node.page_content is None or not document_node.page_content.strip():
+            return
+        with flask_app.app_context():
+            try:
+                # qa model document
+                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
+                document_qa_list = self._format_split_text(response)
+                qa_documents = []
+                for result in document_qa_list:
+                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result['question'])
+                    qa_document.metadata['answer'] = result['answer']
+                    qa_document.metadata['doc_id'] = doc_id
+                    qa_document.metadata['doc_hash'] = hash
+                    qa_documents.append(qa_document)
+                format_documents.extend(qa_documents)
+            except Exception as e:
+                logging.exception(e)
+
+            all_qa_documents.extend(format_documents)
+
+    def _format_split_text(self, text):
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
+        matches = re.findall(regex, text, re.UNICODE)
+
+        return [
+            {
+                "question": q,
+                "answer": re.sub(r"\n\s*", "\n", a.strip())
+            }
+            for q, a in matches if q and a
+        ]
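The regex in `_format_split_text` expects the generated text in a numbered Q/A layout; a sketch of what it parses:

    from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor

    sample = "Q1: What is RAG?\nA1: Retrieval-augmented generation.\nQ2: Why chunk documents?\nA2: To fit the context window."
    pairs = QAIndexProcessor()._format_split_text(sample)
    # -> [{'question': 'What is RAG?', 'answer': 'Retrieval-augmented generation.'},
    #     {'question': 'Why chunk documents?', 'answer': 'To fit the context window.'}]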

+ 0 - 0
api/core/rag/models/__init__.py


+ 16 - 0
api/core/rag/models/document.py

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class Document(BaseModel):
+    """Class for storing a piece of text and associated metadata."""
+
+    page_content: str
+
+    """Arbitrary metadata about the page content (e.g., source, relationships to other
+        documents, etc.).
+    """
+    metadata: Optional[dict] = Field(default_factory=dict)
+
+
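Constructing the new model directly (a sketch; the metadata keys shown are the ones the index processors attach):

    from core.rag.models.document import Document

    doc = Document(page_content="Dify separates extraction, cleaning and indexing.",
                   metadata={"doc_id": "demo-id", "doc_hash": "demo-hash", "source": "example.txt"})
    assert doc.metadata["source"] == "example.txt"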

+ 5 - 5
api/core/tool/web_reader_tool.py

@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
 from regex import regex

 from core.chain.llm_chain import LLMChain
-from core.data_loader import file_extractor
-from core.data_loader.file_extractor import FileExtractor
 from core.entities.application_entities import ModelConfigEntity
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor

 FULL_TEMPLATE = """
 TITLE: {title}
@@ -146,7 +146,7 @@ def get_url(url: str) -> str:
     headers = {
         "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
     }
-    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

     head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
 
 
@@ -158,8 +158,8 @@ def get_url(url: str) -> str:
     if main_content_type not in supported_content_types:
         return "Unsupported content-type [{}] of URL.".format(main_content_type)

-    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
-        return FileExtractor.load_from_url(url, return_text=True)
+    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+        return ExtractProcessor.load_from_url(url, return_text=True)

     response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
     a = extract_using_readabilipy(response.text)

+ 16 - 71
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -6,15 +6,12 @@ from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.embedding.cached_embedding import CacheEmbedding
-from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from services.retrieval_service import RetrievalService

 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool):
 
 
             if dataset.indexing_technique == "economy":
                 # use keyword table query
-                kw_table_index = KeywordTableIndex(
-                    dataset=dataset,
-                    config=KeywordTableConfig(
-                        max_keywords_per_chunk=5
-                    )
-                )
-
-                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=self.top_k
+                                                      )
                 if documents:
                     all_documents.extend(documents)
             else:
-
-                try:
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                except LLMBadRequestError:
-                    return []
-                except ProviderTokenNotInitError:
-                    return []
-
-                embeddings = CacheEmbedding(embedding_model)
-
-                documents = []
-                threads = []
                 if self.top_k > 0:
-                    # retrieval_model source with semantic
-                    if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
-                        'search_method'] == 'hybrid_search':
-                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                            'flask_app': current_app._get_current_object(),
-                            'dataset_id': str(dataset.id),
-                            'query': query,
-                            'top_k': self.top_k,
-                            'score_threshold': self.score_threshold,
-                            'reranking_model': None,
-                            'all_documents': documents,
-                            'search_method': 'hybrid_search',
-                            'embeddings': embeddings
-                        })
-                        threads.append(embedding_thread)
-                        embedding_thread.start()
-
-                    # retrieval_model source with full text
-                    if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
-                        'search_method'] == 'hybrid_search':
-                        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
-                                                                  kwargs={
-                                                                      'flask_app': current_app._get_current_object(),
-                                                                      'dataset_id': str(dataset.id),
-                                                                      'query': query,
-                                                                      'search_method': 'hybrid_search',
-                                                                      'embeddings': embeddings,
-                                                                      'score_threshold': retrieval_model[
-                                                                          'score_threshold'] if retrieval_model[
-                                                                          'score_threshold_enabled'] else None,
-                                                                      'top_k': self.top_k,
-                                                                      'reranking_model': retrieval_model[
-                                                                          'reranking_model'] if retrieval_model[
-                                                                          'reranking_enable'] else None,
-                                                                      'all_documents': documents
-                                                                  })
-                        threads.append(full_text_index_thread)
-                        full_text_index_thread.start()
-
-                    for thread in threads:
-                        thread.join()
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=self.top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
 
 
                     all_documents.extend(documents)
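Both branches above now funnel into the same service call; a standalone sketch of that call, with the argument names used in the hunk (values are placeholders and `dataset` is assumed to be loaded already):

    from core.rag.datasource.retrieval_service import RetrievalService

    documents = RetrievalService.retrieve(
        retrival_method='semantic_search',   # parameter is spelled 'retrival' throughout this commit
        dataset_id=dataset.id,
        query="pricing policy",
        top_k=4,
        score_threshold=0.5,                 # or None when score_threshold_enabled is off
        reranking_model=None                 # or the dataset's reranking_model dict
    )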

+ 17 - 95
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -1,20 +1,12 @@
-import threading
 from typing import Optional

-from flask import current_app
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.embedding.cached_embedding import CacheEmbedding
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
+from core.rag.datasource.retrieval_service import RetrievalService
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from services.retrieval_service import RetrievalService

 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool):
         retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
         if dataset.indexing_technique == "economy":
             # use keyword table query
-            kw_table_index = KeywordTableIndex(
-                dataset=dataset,
-                config=KeywordTableConfig(
-                    max_keywords_per_chunk=5
-                )
-            )
-
-            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+            documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                  dataset_id=dataset.id,
+                                                  query=query,
+                                                  top_k=self.top_k
+                                                  )
             return str("\n".join([document.page_content for document in documents]))
         else:
-            # get embedding model instance
-            try:
-                model_manager = ModelManager()
-                embedding_model = model_manager.get_model_instance(
-                    tenant_id=dataset.tenant_id,
-                    provider=dataset.embedding_model_provider,
-                    model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
-                )
-            except InvokeAuthorizationError:
-                return ''
-
-            embeddings = CacheEmbedding(embedding_model)
-
-            documents = []
-            threads = []
             if self.top_k > 0:
-                # retrieval source with semantic
-                if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
-                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                        'flask_app': current_app._get_current_object(),
-                        'dataset_id': str(dataset.id),
-                        'query': query,
-                        'top_k': self.top_k,
-                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
-                            'reranking_enable'] else None,
-                        'all_documents': documents,
-                        'search_method': retrieval_model['search_method'],
-                        'embeddings': embeddings
-                    })
-                    threads.append(embedding_thread)
-                    embedding_thread.start()
-
-                # retrieval_model source with full text
-                if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
-                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                        'flask_app': current_app._get_current_object(),
-                        'dataset_id': str(dataset.id),
-                        'query': query,
-                        'search_method': retrieval_model['search_method'],
-                        'embeddings': embeddings,
-                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        'top_k': self.top_k,
-                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
-                            'reranking_enable'] else None,
-                        'all_documents': documents
-                    })
-                    threads.append(full_text_index_thread)
-                    full_text_index_thread.start()
-
-                for thread in threads:
-                    thread.join()
-
-                # hybrid search: rerank after all documents have been searched
-                if retrieval_model['search_method'] == 'hybrid_search':
-                    # get rerank model instance
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=retrieval_model['reranking_model']['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=retrieval_model['reranking_model']['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return ''
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    documents = rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        top_n=self.top_k
-                    )
+                # retrieval source
+                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=self.top_k,
+                                                      score_threshold=retrieval_model['score_threshold']
+                                                      if retrieval_model['score_threshold_enabled'] else None,
+                                                      reranking_model=retrieval_model['reranking_model']
+                                                      if retrieval_model['reranking_enable'] else None
+                                                      )
             else:
                 documents = []
 
 
@@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool):
             return str("\n".join(document_context_list))

     async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
+        raise NotImplementedError()

+ 5 - 5
api/core/tools/utils/web_reader_tool.py

@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
 from regex import regex

 from core.chain.llm_chain import LLMChain
-from core.data_loader import file_extractor
-from core.data_loader.file_extractor import FileExtractor
 from core.entities.application_entities import ModelConfigEntity
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor

 FULL_TEMPLATE = """
 TITLE: {title}
@@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str:
     if user_agent:
         headers["User-Agent"] = user_agent

-    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]

     head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
 
 
@@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str:
     if main_content_type not in supported_content_types:
         return "Unsupported content-type [{}] of URL.".format(main_content_type)

-    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
-        return FileExtractor.load_from_url(url, return_text=True)
+    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+        return ExtractProcessor.load_from_url(url, return_text=True)

     response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
     a = extract_using_readabilipy(response.text)

+ 0 - 56
api/core/vector_store/milvus_vector_store.py

@@ -1,56 +0,0 @@
-from core.vector_store.vector.milvus import Milvus
-
-
-class MilvusVectorStore(Milvus):
-    def del_texts(self, where_filter: dict):
-        if not where_filter:
-            raise ValueError('where_filter must not be empty')
-
-        self.col.delete(where_filter.get('filter'))
-
-    def del_text(self, uuid: str) -> None:
-        expr = f"id == {uuid}"
-        self.col.delete(expr)
-
-    def text_exists(self, uuid: str) -> bool:
-        result = self.col.query(
-            expr=f'metadata["doc_id"] == "{uuid}"',
-            output_fields=["id"]
-        )
-
-        return len(result) > 0
-
-    def get_ids_by_document_id(self, document_id: str):
-        result = self.col.query(
-            expr=f'metadata["document_id"] == "{document_id}"',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def get_ids_by_metadata_field(self, key: str, value: str):
-        result = self.col.query(
-            expr=f'metadata["{key}"] == "{value}"',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def get_ids_by_doc_ids(self, doc_ids: list):
-        result = self.col.query(
-            expr=f'metadata["doc_id"] in {doc_ids}',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def delete(self):
-        from pymilvus import utility
-        utility.drop_collection(self.collection_name, None, self.alias)
-

+ 0 - 76
api/core/vector_store/qdrant_vector_store.py

@@ -1,76 +0,0 @@
-from typing import Any, cast
-
-from langchain.schema import Document
-from qdrant_client.http.models import Filter, FilterSelector, PointIdsList
-from qdrant_client.local.qdrant_local import QdrantLocal
-
-from core.vector_store.vector.qdrant import Qdrant
-
-
-class QdrantVectorStore(Qdrant):
-    def del_texts(self, filter: Filter):
-        if not filter:
-            raise ValueError('filter must not be empty')
-
-        self._reload_if_needed()
-
-        self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=FilterSelector(
-                filter=filter
-            ),
-        )
-
-    def del_text(self, uuid: str) -> None:
-        self._reload_if_needed()
-
-        self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=PointIdsList(
-                points=[uuid],
-            ),
-        )
-
-    def text_exists(self, uuid: str) -> bool:
-        self._reload_if_needed()
-
-        response = self.client.retrieve(
-            collection_name=self.collection_name,
-            ids=[uuid]
-        )
-
-        return len(response) > 0
-
-    def delete(self):
-        self._reload_if_needed()
-
-        self.client.delete_collection(collection_name=self.collection_name)
-
-    def delete_group(self):
-        self._reload_if_needed()
-
-        self.client.delete_collection(collection_name=self.collection_name)
-
-    @classmethod
-    def _document_from_scored_point(
-            cls,
-            scored_point: Any,
-            content_payload_key: str,
-            metadata_payload_key: str,
-    ) -> Document:
-        if scored_point.payload.get('doc_id'):
-            return Document(
-                page_content=scored_point.payload.get(content_payload_key),
-                metadata={'doc_id': scored_point.id}
-            )
-
-        return Document(
-            page_content=scored_point.payload.get(content_payload_key),
-            metadata=scored_point.payload.get(metadata_payload_key) or {},
-        )
-
-    def _reload_if_needed(self):
-        if isinstance(self.client, QdrantLocal):
-            self.client = cast(QdrantLocal, self.client)
-            self.client._load()
-

+ 0 - 852
api/core/vector_store/vector/milvus.py

@@ -1,852 +0,0 @@
-"""Wrapper around the Milvus vector database."""
-from __future__ import annotations
-
-import logging
-from collections.abc import Iterable, Sequence
-from typing import Any, Optional, Union
-from uuid import uuid4
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.vectorstores.base import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_MILVUS_CONNECTION = {
-    "host": "localhost",
-    "port": "19530",
-    "user": "",
-    "password": "",
-    "secure": False,
-}
-
-
-class Milvus(VectorStore):
-    """Initialize wrapper around the milvus vector database.
-
-    In order to use this you need to have `pymilvus` installed and a
-    running Milvus
-
-    See the following documentation for how to run a Milvus instance:
-    https://milvus.io/docs/install_standalone-docker.md
-
-    If looking for a hosted Milvus, take a look at this documentation:
-    https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
-    this project,
-
-    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
-
-    Args:
-        embedding_function (Embeddings): Function used to embed the text.
-        collection_name (str): Which Milvus collection to use. Defaults to
-            "LangChainCollection".
-        connection_args (Optional[dict[str, any]]): The connection args used for
-            this class comes in the form of a dict.
-        consistency_level (str): The consistency level to use for a collection.
-            Defaults to "Session".
-        index_params (Optional[dict]): Which index params to use. Defaults to
-            HNSW/AUTOINDEX depending on service.
-        search_params (Optional[dict]): Which search params to use. Defaults to
-            default of index.
-        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
-            to False.
-
-    The connection args used for this class comes in the form of a dict,
-    here are a few of the options:
-        address (str): The actual address of Milvus
-            instance. Example address: "localhost:19530"
-        uri (str): The uri of Milvus instance. Example uri:
-            "http://randomwebsite:19530",
-            "tcp:foobarsite:19530",
-            "https://ok.s3.south.com:19530".
-        host (str): The host of Milvus instance. Default at "localhost",
-            PyMilvus will fill in the default host if only port is provided.
-        port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
-            will fill in the default port if only host is provided.
-        user (str): Use which user to connect to Milvus instance. If user and
-            password are provided, we will add related header in every RPC call.
-        password (str): Required when user is provided. The password
-            corresponding to the user.
-        secure (bool): Default is false. If set to true, tls will be enabled.
-        client_key_path (str): If use tls two-way authentication, need to
-            write the client.key path.
-        client_pem_path (str): If use tls two-way authentication, need to
-            write the client.pem path.
-        ca_pem_path (str): If use tls two-way authentication, need to write
-            the ca.pem path.
-        server_pem_path (str): If use tls one-way authentication, need to
-            write the server.pem path.
-        server_name (str): If use tls, need to write the common name.
-
-    Example:
-        .. code-block:: python
-
-        from langchain import Milvus
-        from langchain.embeddings import OpenAIEmbeddings
-
-        embedding = OpenAIEmbeddings()
-        # Connect to a milvus instance on localhost
-        milvus_store = Milvus(
-            embedding_function = Embeddings,
-            collection_name = "LangChainCollection",
-            drop_old = True,
-        )
-
-    Raises:
-        ValueError: If the pymilvus python package is not installed.
-    """
-
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        index_params: Optional[dict] = None,
-        search_params: Optional[dict] = None,
-        drop_old: Optional[bool] = False,
-    ):
-        """Initialize the Milvus vector store."""
-        try:
-            from pymilvus import Collection, utility
-        except ImportError:
-            raise ValueError(
-                "Could not import pymilvus python package. "
-                "Please install it with `pip install pymilvus`."
-            )
-
-        # Default search params when one is not provided.
-        self.default_search_params = {
-            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
-            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
-            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
-            "AUTOINDEX": {"metric_type": "L2", "params": {}},
-        }
-
-        self.embedding_func = embedding_function
-        self.collection_name = collection_name
-        self.index_params = index_params
-        self.search_params = search_params
-        self.consistency_level = consistency_level
-
-        # In order for a collection to be compatible, pk needs to be auto'id and int
-        self._primary_field = "id"
-        # In order for compatibility, the text field will need to be called "text"
-        self._text_field = "page_content"
-        # In order for compatibility, the vector field needs to be called "vector"
-        self._vector_field = "vectors"
-        # In order for compatibility, the metadata field will need to be called "metadata"
-        self._metadata_field = "metadata"
-        self.fields: list[str] = []
-        # Create the connection to the server
-        if connection_args is None:
-            connection_args = DEFAULT_MILVUS_CONNECTION
-        self.alias = self._create_connection_alias(connection_args)
-        self.col: Optional[Collection] = None
-
-        # Grab the existing collection if it exists
-        if utility.has_collection(self.collection_name, using=self.alias):
-            self.col = Collection(
-                self.collection_name,
-                using=self.alias,
-            )
-        # If need to drop old, drop it
-        if drop_old and isinstance(self.col, Collection):
-            self.col.drop()
-            self.col = None
-
-        # Initialize the vector store
-        self._init()
-
-    @property
-    def embeddings(self) -> Embeddings:
-        return self.embedding_func
-
-    def _create_connection_alias(self, connection_args: dict) -> str:
-        """Create the connection to the Milvus server."""
-        from pymilvus import MilvusException, connections
-
-        # Grab the connection arguments that are used for checking existing connection
-        host: str = connection_args.get("host", None)
-        port: Union[str, int] = connection_args.get("port", None)
-        address: str = connection_args.get("address", None)
-        uri: str = connection_args.get("uri", None)
-        user = connection_args.get("user", None)
-
-        # Order of use is host/port, uri, address
-        if host is not None and port is not None:
-            given_address = str(host) + ":" + str(port)
-        elif uri is not None:
-            given_address = uri.split("https://")[1]
-        elif address is not None:
-            given_address = address
-        else:
-            given_address = None
-            logger.debug("Missing standard address type for reuse atttempt")
-
-        # User defaults to empty string when getting connection info
-        if user is not None:
-            tmp_user = user
-        else:
-            tmp_user = ""
-
-        # If a valid address was given, then check if a connection exists
-        if given_address is not None:
-            for con in connections.list_connections():
-                addr = connections.get_connection_addr(con[0])
-                if (
-                    con[1]
-                    and ("address" in addr)
-                    and (addr["address"] == given_address)
-                    and ("user" in addr)
-                    and (addr["user"] == tmp_user)
-                ):
-                    logger.debug("Using previous connection: %s", con[0])
-                    return con[0]
-
-        # Generate a new connection if one doesn't exist
-        alias = uuid4().hex
-        try:
-            connections.connect(alias=alias, **connection_args)
-            logger.debug("Created new connection using: %s", alias)
-            return alias
-        except MilvusException as e:
-            logger.error("Failed to create new connection using: %s", alias)
-            raise e
-
-    def _init(
-        self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
-    ) -> None:
-        if embeddings is not None:
-            self._create_collection(embeddings, metadatas)
-        self._extract_fields()
-        self._create_index()
-        self._create_search_params()
-        self._load()
-
-    def _create_collection(
-        self, embeddings: list, metadatas: Optional[list[dict]] = None
-    ) -> None:
-        from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException
-        from pymilvus.orm.types import infer_dtype_bydata
-
-        # Determine embedding dim
-        dim = len(embeddings[0])
-        fields = []
-        # Determine metadata schema
-        # if metadatas:
-        #     # Create FieldSchema for each entry in metadata.
-        #     for key, value in metadatas[0].items():
-        #         # Infer the corresponding datatype of the metadata
-        #         dtype = infer_dtype_bydata(value)
-        #         # Datatype isn't compatible
-        #         if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
-        #             logger.error(
-        #                 "Failure to create collection, unrecognized dtype for key: %s",
-        #                 key,
-        #             )
-        #             raise ValueError(f"Unrecognized datatype for {key}.")
-        #         # Dataype is a string/varchar equivalent
-        #         elif dtype == DataType.VARCHAR:
-        #             fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
-        #         else:
-        #             fields.append(FieldSchema(key, dtype))
-        if metadatas:
-            fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
-
-        # Create the text field
-        fields.append(
-            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
-        )
-        # Create the primary key field
-        fields.append(
-            FieldSchema(
-                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
-            )
-        )
-        # Create the vector field, supports binary or float vectors
-        fields.append(
-            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
-        )
-
-        # Create the schema for the collection
-        schema = CollectionSchema(fields)
-
-        # Create the collection
-        try:
-            self.col = Collection(
-                name=self.collection_name,
-                schema=schema,
-                consistency_level=self.consistency_level,
-                using=self.alias,
-            )
-        except MilvusException as e:
-            logger.error(
-                "Failed to create collection: %s error: %s", self.collection_name, e
-            )
-            raise e
-
-    def _extract_fields(self) -> None:
-        """Grab the existing fields from the Collection"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection):
-            schema = self.col.schema
-            for x in schema.fields:
-                self.fields.append(x.name)
-            # Since primary field is auto-id, no need to track it
-            self.fields.remove(self._primary_field)
-
-    def _get_index(self) -> Optional[dict[str, Any]]:
-        """Return the vector index information if it exists"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection):
-            for x in self.col.indexes:
-                if x.field_name == self._vector_field:
-                    return x.to_dict()
-        return None
-
-    def _create_index(self) -> None:
-        """Create a index on the collection"""
-        from pymilvus import Collection, MilvusException
-
-        if isinstance(self.col, Collection) and self._get_index() is None:
-            try:
-                # If no index params, use a default HNSW based one
-                if self.index_params is None:
-                    self.index_params = {
-                        "metric_type": "IP",
-                        "index_type": "HNSW",
-                        "params": {"M": 8, "efConstruction": 64},
-                    }
-
-                try:
-                    self.col.create_index(
-                        self._vector_field,
-                        index_params=self.index_params,
-                        using=self.alias,
-                    )
-
-                # If default did not work, most likely on Zilliz Cloud
-                except MilvusException:
-                    # Use AUTOINDEX based index
-                    self.index_params = {
-                        "metric_type": "L2",
-                        "index_type": "AUTOINDEX",
-                        "params": {},
-                    }
-                    self.col.create_index(
-                        self._vector_field,
-                        index_params=self.index_params,
-                        using=self.alias,
-                    )
-                logger.debug(
-                    "Successfully created an index on collection: %s",
-                    self.collection_name,
-                )
-
-            except MilvusException as e:
-                logger.error(
-                    "Failed to create an index on collection: %s", self.collection_name
-                )
-                raise e
-
-    def _create_search_params(self) -> None:
-        """Generate search params based on the current index type"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection) and self.search_params is None:
-            index = self._get_index()
-            if index is not None:
-                index_type: str = index["index_param"]["index_type"]
-                metric_type: str = index["index_param"]["metric_type"]
-                self.search_params = self.default_search_params[index_type]
-                self.search_params["metric_type"] = metric_type
-
-    def _load(self) -> None:
-        """Load the collection if available."""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection) and self._get_index() is not None:
-            self.col.load()
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        timeout: Optional[int] = None,
-        batch_size: int = 1000,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Insert text data into Milvus.
-
-        Inserting data when the collection has not been made yet will result
-        in creating a new Collection. The data of the first entity decides
-        the schema of the new collection: the dim is extracted from the first
-        embedding and the columns are decided by the first metadata dict.
-        Metadata keys will need to be present for all inserted values. At
-        the moment there is no None equivalent in Milvus.
-
-        Args:
-            texts (Iterable[str]): The texts to embed; it is assumed
-                that they all fit in memory.
-            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
-                the texts. Defaults to None.
-            timeout (Optional[int]): Timeout for each batch insert. Defaults
-                to None.
-            batch_size (int, optional): Batch size to use for insertion.
-                Defaults to 1000.
-
-        Raises:
-            MilvusException: Failure to add texts
-
-        Returns:
-            List[str]: The resulting keys for each inserted element.
-        """
-        from pymilvus import Collection, MilvusException
-
-        texts = list(texts)
-
-        try:
-            embeddings = self.embedding_func.embed_documents(texts)
-        except NotImplementedError:
-            embeddings = [self.embedding_func.embed_query(x) for x in texts]
-
-        if len(embeddings) == 0:
-            logger.debug("Nothing to insert, skipping.")
-            return []
-
-        # If the collection hasn't been initialized yet, perform all steps to do so
-        if not isinstance(self.col, Collection):
-            self._init(embeddings, metadatas)
-
-        # Dict to hold all insert columns
-        insert_dict: dict[str, list] = {
-            self._text_field: texts,
-            self._vector_field: embeddings,
-        }
-
-        # Collect the metadata into the insert dict.
-        # if metadatas is not None:
-        #     for d in metadatas:
-        #         for key, value in d.items():
-        #             if key in self.fields:
-        #                 insert_dict.setdefault(key, []).append(value)
-        if metadatas is not None:
-            for d in metadatas:
-                insert_dict.setdefault(self._metadata_field, []).append(d)
-
-        # Total insert count
-        vectors: list = insert_dict[self._vector_field]
-        total_count = len(vectors)
-
-        pks: list[str] = []
-
-        assert isinstance(self.col, Collection)
-        for i in range(0, total_count, batch_size):
-            # Grab end index
-            end = min(i + batch_size, total_count)
-            # Convert dict to list of lists batch for insertion
-            insert_list = [insert_dict[x][i:end] for x in self.fields]
-            # Insert into the collection.
-            try:
-                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
-                pks.extend(res.primary_keys)
-            except MilvusException as e:
-                logger.error(
-                    "Failed to insert batch starting at entity: %s/%s", i, total_count
-                )
-                raise e
-        return pks
-
-    def similarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a similarity search against the query string.
-
-        Args:
-            query (str): The text to search.
-            k (int, optional): How many results to return. Defaults to 4.
-            param (dict, optional): The search params for the index type.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-        res = self.similarity_search_with_score(
-            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return [doc for doc, _ in res]
-
-    def similarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a similarity search against the query string.
-
-        Args:
-            embedding (List[float]): The embedding vector to search.
-            k (int, optional): How many results to return. Defaults to 4.
-            param (dict, optional): The search params for the index type.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-        res = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return [doc for doc, _ in res]
-
-    def similarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Perform a search on a query string and return results with score.
-
-        For more information about the search parameters, take a look at the pymilvus
-        documentation found here:
-        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
-
-        Args:
-            query (str): The text being searched.
-            k (int, optional): The amount of results to return. Defaults to 4.
-            param (dict): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Tuple[Document, float]]: Result doc and score.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        # Embed the query text.
-        embedding = self.embedding_func.embed_query(query)
-
-        res = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return res
-
-    def _similarity_search_with_relevance_scores(
-        self,
-        query: str,
-        k: int = 4,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs and relevance scores in the range [0, 1].
-
-        0 is dissimilar, 1 is most similar.
-
-        Args:
-            query: input text
-            k: Number of Documents to return. Defaults to 4.
-            **kwargs: kwargs to be passed to similarity search. Should include:
-                score_threshold: Optional, a floating point value between 0 and 1 to
-                    filter the resulting set of retrieved docs
-
-        Returns:
-            List of Tuples of (doc, similarity_score)
-        """
-        return self.similarity_search_with_score(query, k, **kwargs)
-
-    def similarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Perform a search on a query string and return results with score.
-
-        For more information about the search parameters, take a look at the pymilvus
-        documentation found here:
-        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
-
-        Args:
-            embedding (List[float]): The embedding vector being searched.
-            k (int, optional): The amount of results to return. Defaults to 4.
-            param (dict): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Tuple[Document, float]]: Result doc and score.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        if param is None:
-            param = self.search_params
-
-        # Determine result metadata fields.
-        output_fields = self.fields[:]
-        output_fields.remove(self._vector_field)
-
-        # Perform the search.
-        res = self.col.search(
-            data=[embedding],
-            anns_field=self._vector_field,
-            param=param,
-            limit=k,
-            expr=expr,
-            output_fields=output_fields,
-            timeout=timeout,
-            **kwargs,
-        )
-        # Organize results.
-        ret = []
-        for result in res[0]:
-            meta = {x: result.entity.get(x) for x in output_fields}
-            doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
-            pair = (doc, result.score)
-            ret.append(pair)
-
-        return ret
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a search and return results that are reordered by MMR.
-
-        Args:
-            query (str): The text being searched.
-            k (int, optional): How many results to give. Defaults to 4.
-            fetch_k (int, optional): Total results to select k from.
-                Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5
-            param (dict, optional): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        embedding = self.embedding_func.embed_query(query)
-
-        return self.max_marginal_relevance_search_by_vector(
-            embedding=embedding,
-            k=k,
-            fetch_k=fetch_k,
-            lambda_mult=lambda_mult,
-            param=param,
-            expr=expr,
-            timeout=timeout,
-            **kwargs,
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a search and return results that are reordered by MMR.
-
-        Args:
-            embedding (List[float]): The embedding vector being searched.
-            k (int, optional): How many results to give. Defaults to 4.
-            fetch_k (int, optional): Total results to select k from.
-                Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5
-            param (dict, optional): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        if param is None:
-            param = self.search_params
-
-        # Determine result metadata fields.
-        output_fields = self.fields[:]
-        output_fields.remove(self._vector_field)
-
-        # Perform the search.
-        res = self.col.search(
-            data=[embedding],
-            anns_field=self._vector_field,
-            param=param,
-            limit=fetch_k,
-            expr=expr,
-            output_fields=output_fields,
-            timeout=timeout,
-            **kwargs,
-        )
-        # Organize results.
-        ids = []
-        documents = []
-        scores = []
-        for result in res[0]:
-            meta = {x: result.entity.get(x) for x in output_fields}
-            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
-            documents.append(doc)
-            scores.append(result.score)
-            ids.append(result.id)
-
-        vectors = self.col.query(
-            expr=f"{self._primary_field} in {ids}",
-            output_fields=[self._primary_field, self._vector_field],
-            timeout=timeout,
-        )
-        # Reorganize the results from query to match search order.
-        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
-
-        ordered_result_embeddings = [vectors[x] for x in ids]
-
-        # Get the new order of results.
-        new_ordering = maximal_marginal_relevance(
-            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
-        )
-
-        # Reorder the values and return.
-        ret = []
-        for x in new_ordering:
-            # Function can return -1 index
-            if x == -1:
-                break
-            else:
-                ret.append(documents[x])
-        return ret
-
-    @classmethod
-    def from_texts(
-        cls,
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        collection_name: str = "LangChainCollection",
-        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
-        consistency_level: str = "Session",
-        index_params: Optional[dict] = None,
-        search_params: Optional[dict] = None,
-        drop_old: bool = False,
-        batch_size: int = 100,
-        ids: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> Milvus:
-        """Create a Milvus collection, indexes it with HNSW, and insert data.
-
-        Args:
-            texts (List[str]): Text data.
-            embedding (Embeddings): Embedding function.
-            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
-                Defaults to None.
-            collection_name (str, optional): Collection name to use. Defaults to
-                "LangChainCollection".
-            connection_args (dict[str, Any], optional): Connection args to use. Defaults
-                to DEFAULT_MILVUS_CONNECTION.
-            consistency_level (str, optional): Which consistency level to use. Defaults
-                to "Session".
-            index_params (Optional[dict], optional): Which index_params to use. Defaults
-                to None.
-            search_params (Optional[dict], optional): Which search params to use.
-                Defaults to None.
-            drop_old (Optional[bool], optional): Whether to drop the collection with
-                that name if it exists. Defaults to False.
-            batch_size (int, optional): How many vectors to upload per request.
-                Defaults to 100.
-            ids (Optional[Sequence[str]], optional): Optional ids for the
-                inserted texts. Defaults to None.
-
-        Returns:
-            Milvus: Milvus Vector Store
-        """
-        vector_db = cls(
-            embedding_function=embedding,
-            collection_name=collection_name,
-            connection_args=connection_args,
-            consistency_level=consistency_level,
-            index_params=index_params,
-            search_params=search_params,
-            drop_old=drop_old,
-            **kwargs,
-        )
-        vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
-        return vector_db
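
For reference, the wrapper removed above was exercised through its `from_texts` and `similarity_search` entry points. Below is a minimal sketch, assuming the `Milvus` class defined above is importable; the stub embedding model, example texts, and connection values are placeholders for illustration and do not come from this diff:

    # Hypothetical usage of the removed Milvus wrapper (sketch only).
    from langchain.embeddings.base import Embeddings

    class FakeEmbeddings(Embeddings):
        """Placeholder embedding model; returns fixed 3-dim vectors."""
        def embed_documents(self, texts):
            return [[0.1, 0.2, 0.3] for _ in texts]

        def embed_query(self, text):
            return [0.1, 0.2, 0.3]

    # Create a collection from texts and run a top-2 similarity search.
    # Host/port are placeholders for a locally running Milvus instance.
    milvus_store = Milvus.from_texts(
        texts=["first document", "second document"],
        embedding=FakeEmbeddings(),
        metadatas=[{"source": "a"}, {"source": "b"}],
        collection_name="ExampleCollection",
        connection_args={"host": "127.0.0.1", "port": "19530"},
    )
    results = milvus_store.similarity_search("first document", k=2)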

+ 0 - 1759
api/core/vector_store/vector/qdrant.py

@@ -1,1759 +0,0 @@
-"""Wrapper around Qdrant vector database."""
-from __future__ import annotations
-
-import asyncio
-import functools
-import uuid
-import warnings
-from collections.abc import Callable, Generator, Iterable, Sequence
-from itertools import islice
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.vectorstores import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType
-
-if TYPE_CHECKING:
-    from qdrant_client import grpc  # noqa
-    from qdrant_client.conversions import common_types
-    from qdrant_client.http import models as rest
-
-    DictFilter = dict[str, Union[str, int, bool, dict, list]]
-    MetadataFilter = Union[DictFilter, common_types.Filter]
-
-
-class QdrantException(Exception):
-    """Base class for all the Qdrant related exceptions"""
-
-
-def sync_call_fallback(method: Callable) -> Callable:
-    """
-    Decorator to call the synchronous method of the class if the async method is not
-    implemented. This decorator might be only used for the methods that are defined
-    as async in the class.
-    """
-
-    @functools.wraps(method)
-    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
-        try:
-            return await method(self, *args, **kwargs)
-        except NotImplementedError:
-            # If the async method is not implemented, call the synchronous method
-            # by removing the first letter from the method name. For example,
-            # if the async method is called ``aadd_texts``, the synchronous method
-            # will be called ``add_texts``.
-            sync_method = functools.partial(
-                getattr(self, method.__name__[1:]), *args, **kwargs
-            )
-            return await asyncio.get_event_loop().run_in_executor(None, sync_method)
-
-    return wrapper
-
-
-class Qdrant(VectorStore):
-    """Wrapper around Qdrant vector database.
-
-    To use you should have the ``qdrant-client`` package installed.
-
-    Example:
-        .. code-block:: python
-
-            from qdrant_client import QdrantClient
-            from langchain import Qdrant
-
-            client = QdrantClient()
-            collection_name = "MyCollection"
-            qdrant = Qdrant(client, collection_name, embedding_function)
-    """
-    """
-
-    CONTENT_KEY = "page_content"
-    METADATA_KEY = "metadata"
-    GROUP_KEY = "group_id"
-    VECTOR_NAME = None
-
-    def __init__(
-        self,
-        client: Any,
-        collection_name: str,
-        embeddings: Optional[Embeddings] = None,
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        distance_strategy: str = "COSINE",
-        vector_name: Optional[str] = VECTOR_NAME,
-        embedding_function: Optional[Callable] = None,  # deprecated
-        is_new_collection: bool = False
-    ):
-        """Initialize with necessary components."""
-        try:
-            import qdrant_client
-        except ImportError:
-            raise ValueError(
-                "Could not import qdrant-client python package. "
-                "Please install it with `pip install qdrant-client`."
-            )
-
-        if not isinstance(client, qdrant_client.QdrantClient):
-            raise ValueError(
-                f"client should be an instance of qdrant_client.QdrantClient, "
-                f"got {type(client)}"
-            )
-
-        if embeddings is None and embedding_function is None:
-            raise ValueError(
-                "`embeddings` value can't be None. Pass `Embeddings` instance."
-            )
-
-        if embeddings is not None and embedding_function is not None:
-            raise ValueError(
-                "Both `embeddings` and `embedding_function` are passed. "
-                "Use `embeddings` only."
-            )
-
-        self._embeddings = embeddings
-        self._embeddings_function = embedding_function
-        self.client: qdrant_client.QdrantClient = client
-        self.collection_name = collection_name
-        self.content_payload_key = content_payload_key or self.CONTENT_KEY
-        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
-        self.group_payload_key = group_payload_key or self.GROUP_KEY
-        self.vector_name = vector_name or self.VECTOR_NAME
-        self.group_id = group_id
-        self.is_new_collection = is_new_collection
-
-        if embedding_function is not None:
-            warnings.warn(
-                "Using `embedding_function` is deprecated. "
-                "Pass `Embeddings` instance to `embeddings` instead."
-            )
-
-        if not isinstance(embeddings, Embeddings):
-            warnings.warn(
-                "`embeddings` should be an instance of `Embeddings`."
-                "Using `embeddings` as `embedding_function` which is deprecated"
-            )
-            self._embeddings_function = embeddings
-            self._embeddings = None
-
-        self.distance_strategy = distance_strategy.upper()
-
-    @property
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embeddings
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            batch_size:
-                How many vectors to upload per request.
-                Default: 64
-            group_id:
-                collection group
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
-            texts, metadatas, ids, batch_size
-        ):
-            self.client.upsert(
-                collection_name=self.collection_name, points=points
-            )
-            added_ids.extend(batch_ids)
-        # if this is a new collection, create a payload index on group_id
-        if self.is_new_collection:
-            # create payload index
-            self.client.create_payload_index(self.collection_name, self.group_payload_key,
-                                             field_schema=PayloadSchemaType.KEYWORD,
-                                             field_type=PayloadSchemaType.KEYWORD)
-            # create a full-text index
-            text_index_params = TextIndexParams(
-                type=TextIndexType.TEXT,
-                tokenizer=TokenizerType.MULTILINGUAL,
-                min_token_len=2,
-                max_token_len=20,
-                lowercase=True
-            )
-            self.client.create_payload_index(self.collection_name, self.content_payload_key,
-                                             field_schema=text_index_params)
-        return added_ids
-
-    @sync_call_fallback
-    async def aadd_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            batch_size:
-                How many vectors to upload per request.
-                Default: 64
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import RestToGrpc
-
-        added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
-            texts, metadatas, ids, batch_size
-        ):
-            await self.client.async_grpc_points.Upsert(
-                grpc.UpsertPoints(
-                    collection_name=self.collection_name,
-                    points=[RestToGrpc.convert_point_struct(point) for point in points],
-                )
-            )
-            added_ids.extend(batch_ids)
-
-        return added_ids
-
-    def similarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = self.similarity_search_with_score(
-            query,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def asimilarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to query.
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
-        return list(map(itemgetter(0), results))
-
-    def similarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        return self.similarity_search_with_score_by_vector(
-            self._embed_query(query),
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-
-    @sync_call_fallback
-    async def asimilarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        return await self.asimilarity_search_with_score_by_vector(
-            self._embed_query(query),
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-
-    def similarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = self.similarity_search_with_score_by_vector(
-            embedding,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def asimilarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = await self.asimilarity_search_with_score_by_vector(
-            embedding,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    def similarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        if filter is not None and isinstance(filter, dict):
-            warnings.warn(
-                "Using dict as a `filter` is deprecated. Please use qdrant-client "
-                "filters directly: "
-                "https://qdrant.tech/documentation/concepts/filtering/",
-                DeprecationWarning,
-            )
-            qdrant_filter = self._qdrant_filter_from_dict(filter)
-        else:
-            qdrant_filter = filter
-
-        query_vector = embedding
-        if self.vector_name is not None:
-            query_vector = (self.vector_name, embedding)  # type: ignore[assignment]
-
-        results = self.client.search(
-            collection_name=self.collection_name,
-            query_vector=query_vector,
-            query_filter=qdrant_filter,
-            search_params=search_params,
-            limit=k,
-            offset=offset,
-            with_payload=True,
-            with_vectors=True,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return [
-            (
-                self._document_from_scored_point(
-                    result, self.content_payload_key, self.metadata_payload_key
-                ),
-                result.score,
-            )
-            for result in results
-        ]
-
-    def similarity_search_by_bm25(
-        self,
-        filter: Optional[MetadataFilter] = None,
-        k: int = 4
-    ) -> list[Document]:
-        """Return docs most similar by bm25.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        response = self.client.scroll(
-            collection_name=self.collection_name,
-            scroll_filter=filter,
-            limit=k,
-            with_payload=True,
-            with_vectors=True
-        )
-        results = response[0]
-        documents = []
-        for result in results:
-            if result:
-                documents.append(self._document_from_scored_point(
-                        result, self.content_payload_key, self.metadata_payload_key
-                    ))
-
-        return documents
-
-    @sync_call_fallback
-    async def asimilarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should be present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import RestToGrpc
-        from qdrant_client.http import models as rest
-
-        if filter is not None and isinstance(filter, dict):
-            warnings.warn(
-                "Using dict as a `filter` is deprecated. Please use qdrant-client "
-                "filters directly: "
-                "https://qdrant.tech/documentation/concepts/filtering/",
-                DeprecationWarning,
-            )
-            qdrant_filter = self._qdrant_filter_from_dict(filter)
-        else:
-            qdrant_filter = filter
-
-        if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
-            qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
-
-        response = await self.client.async_grpc_points.Search(
-            grpc.SearchPoints(
-                collection_name=self.collection_name,
-                vector_name=self.vector_name,
-                vector=embedding,
-                filter=qdrant_filter,
-                params=search_params,
-                limit=k,
-                offset=offset,
-                with_payload=grpc.WithPayloadSelector(enable=True),
-                with_vectors=grpc.WithVectorsSelector(enable=False),
-                score_threshold=score_threshold,
-                read_consistency=consistency,
-                **kwargs,
-            )
-        )
-
-        return [
-            (
-                self._document_from_scored_point_grpc(
-                    result, self.content_payload_key, self.metadata_payload_key
-                ),
-                result.score,
-            )
-            for result in response.result
-        ]
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        query_embedding = self._embed_query(query)
-        return self.max_marginal_relevance_search_by_vector(
-            query_embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        query_embedding = self._embed_query(query)
-        return await self.amax_marginal_relevance_search_by_vector(
-            query_embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        results = self.max_marginal_relevance_search_with_score_by_vector(
-            embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        results = await self.amax_marginal_relevance_search_with_score_by_vector(
-            embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-        return list(map(itemgetter(0), results))
-
-    def max_marginal_relevance_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        query_vector = embedding
-        if self.vector_name is not None:
-            query_vector = (self.vector_name, query_vector)  # type: ignore[assignment]
-
-        results = self.client.search(
-            collection_name=self.collection_name,
-            query_vector=query_vector,
-            with_payload=True,
-            with_vectors=True,
-            limit=fetch_k,
-        )
-        embeddings = [
-            result.vector.get(self.vector_name)  # type: ignore[index, union-attr]
-            if self.vector_name is not None
-            else result.vector
-            for result in results
-        ]
-        mmr_selected = maximal_marginal_relevance(
-            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
-        )
-        return [
-            (
-                self._document_from_scored_point(
-                    results[i], self.content_payload_key, self.metadata_payload_key
-                ),
-                results[i].score,
-            )
-            for i in mmr_selected
-        ]
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import GrpcToRest
-
-        response = await self.client.async_grpc_points.Search(
-            grpc.SearchPoints(
-                collection_name=self.collection_name,
-                vector_name=self.vector_name,
-                vector=embedding,
-                with_payload=grpc.WithPayloadSelector(enable=True),
-                with_vectors=grpc.WithVectorsSelector(enable=True),
-                limit=fetch_k,
-            )
-        )
-        results = [
-            GrpcToRest.convert_vectors(result.vectors) for result in response.result
-        ]
-        embeddings: list[list[float]] = [
-            result.get(self.vector_name)  # type: ignore
-            if isinstance(result, dict)
-            else result
-            for result in results
-        ]
-        mmr_selected: list[int] = maximal_marginal_relevance(
-            np.array(embedding),
-            embeddings,
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        return [
-            (
-                self._document_from_scored_point_grpc(
-                    response.result[i],
-                    self.content_payload_key,
-                    self.metadata_payload_key,
-                ),
-                response.result[i].score,
-            )
-            for i in mmr_selected
-        ]
-
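Editor's note: both MMR variants above hand the actual selection off to langchain's maximal_marginal_relevance helper. For orientation, here is a minimal self-contained sketch of that selection rule, assuming cosine similarity; it only illustrates how lambda_mult trades query similarity against diversity and is not the helper's actual source.

import numpy as np

def mmr_select(query_vec, candidate_vecs, k=4, lambda_mult=0.5):
    """Pick k candidate indices, balancing query similarity against redundancy."""
    query = np.asarray(query_vec, dtype=float)
    cands = np.asarray(candidate_vecs, dtype=float)
    # cosine similarity of every candidate to the query
    sim_to_query = cands @ query / (
        np.linalg.norm(cands, axis=1) * np.linalg.norm(query) + 1e-10
    )
    selected = [int(np.argmax(sim_to_query))]
    while len(selected) < min(k, len(cands)):
        chosen = cands[selected]
        # highest similarity of each candidate to anything already selected
        sim_to_chosen = cands @ chosen.T / (
            np.linalg.norm(cands, axis=1)[:, None]
            * np.linalg.norm(chosen, axis=1) + 1e-10
        )
        redundancy = sim_to_chosen.max(axis=1)
        score = lambda_mult * sim_to_query - (1 - lambda_mult) * redundancy
        score[selected] = -np.inf  # never re-pick an already selected document
        selected.append(int(np.argmax(score)))
    return selected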
-    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
-        """Delete by vector ID or other criteria.
-
-        Args:
-            ids: List of ids to delete.
-            **kwargs: Other keyword arguments that subclasses might use.
-
-        Returns:
-            Optional[bool]: True if deletion is successful,
-            False otherwise, None if not implemented.
-        """
-        from qdrant_client.http import models as rest
-
-        result = self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=ids,
-        )
-        return result.status == rest.UpdateStatus.COMPLETED
-
-    @classmethod
-    def from_texts(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        vector_name: Optional[str] = VECTOR_NAME,
-        batch_size: int = 64,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        """Construct Qdrant wrapper from a list of texts.
-
-        Args:
-            texts: A list of texts to be indexed in Qdrant.
-            embedding: A subclass of `Embeddings`, responsible for text vectorization.
-            metadatas:
-                An optional list of metadata. If provided it has to be of the same
-                length as a list of texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            location:
-                If `:memory:` - use in-memory Qdrant instance.
-                If `str` - use it as a `url` parameter.
-                If `None` - fallback to relying on `host` and `port` parameters.
-            url: either host or str of "Optional[scheme], host, Optional[port],
-                Optional[prefix]". Default: `None`
-            port: Port of the REST API interface. Default: 6333
-            grpc_port: Port of the gRPC interface. Default: 6334
-            prefer_grpc:
-                If true - use gRPC interface whenever possible in custom methods.
-                Default: False
-            https: If true - use HTTPS(SSL) protocol. Default: None
-            api_key: API key for authentication in Qdrant Cloud. Default: None
-            prefix:
-                If not None - add prefix to the REST URL path.
-                Example: service/v1 will result in
-                    http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
-                Default: None
-            timeout:
-                Timeout for REST and gRPC API requests.
-                Default: 5.0 seconds for REST and unlimited for gRPC
-            host:
-                Host name of Qdrant service. If url and host are None, set to
-                'localhost'. Default: None
-            path:
-                Path in which the vectors will be stored while using local mode.
-                Default: None
-            collection_name:
-                Name of the Qdrant collection to be used. If not provided,
-                it will be created randomly. Default: None
-            distance_func:
-                Distance function. One of: "Cosine" / "Euclid" / "Dot".
-                Default: "Cosine"
-            content_payload_key:
-                A payload key used to store the content of the document.
-                Default: "page_content"
-            metadata_payload_key:
-                A payload key used to store the metadata of the document.
-                Default: "metadata"
-            group_payload_key:
-                A payload key used to store the group id of the document.
-                Default: "group_id"
-            group_id:
-                collection group id
-            vector_name:
-                Name of the vector to be used internally in Qdrant.
-                Default: None
-            batch_size:
-                How many vectors to upload per request.
-                Default: 64
-            shard_number: Number of shards in collection. Default is 1, minimum is 1.
-            replication_factor:
-                Replication factor for collection. Default is 1, minimum is 1.
-                Defines how many copies of each shard will be created.
-                Has effect only in distributed mode.
-            write_consistency_factor:
-                Write consistency factor for collection. Default is 1, minimum is 1.
-                Defines how many replicas should apply the operation for us to consider
-                it successful. Increasing this number will make the collection more
-                resilient to inconsistencies, but will also make it fail if not enough
-                replicas are available.
-                Does not have any performance impact.
-                Has effect only in distributed mode.
-            on_disk_payload:
-                If true - point's payload will not be stored in memory.
-                It will be read from the disk every time it is requested.
-                This setting saves RAM by (slightly) increasing the response time.
-                Note: those payload values that are involved in filtering and are
-                indexed - remain in RAM.
-            hnsw_config: Params for HNSW index
-            optimizers_config: Params for optimizer
-            wal_config: Params for Write-Ahead-Log
-            quantization_config:
-                Params for quantization, if None - quantization will be disabled
-            init_from:
-                Use data stored in another collection to initialize this collection
-            force_recreate:
-                Force recreating the collection
-            **kwargs:
-                Additional arguments passed directly into REST client initialization
-
-        This is a user-friendly interface that:
-        1. Creates embeddings, one for each text
-        2. Initializes the Qdrant database as an in-memory docstore by default
-           (and overridable to a remote docstore)
-        3. Adds the text embeddings to the Qdrant database
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain import Qdrant
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
-        """
-        qdrant = cls._construct_instance(
-            texts,
-            embedding,
-            metadatas,
-            ids,
-            location,
-            url,
-            port,
-            grpc_port,
-            prefer_grpc,
-            https,
-            api_key,
-            prefix,
-            timeout,
-            host,
-            path,
-            collection_name,
-            distance_func,
-            content_payload_key,
-            metadata_payload_key,
-            group_payload_key,
-            group_id,
-            vector_name,
-            shard_number,
-            replication_factor,
-            write_consistency_factor,
-            on_disk_payload,
-            hnsw_config,
-            optimizers_config,
-            wal_config,
-            quantization_config,
-            init_from,
-            force_recreate,
-            **kwargs,
-        )
-        qdrant.add_texts(texts, metadatas, ids, batch_size)
-        return qdrant
-
-    @classmethod
-    @sync_call_fallback
-    async def afrom_texts(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        vector_name: Optional[str] = VECTOR_NAME,
-        batch_size: int = 64,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        """Construct Qdrant wrapper from a list of texts.
-
-        Args:
-            texts: A list of texts to be indexed in Qdrant.
-            embedding: A subclass of `Embeddings`, responsible for text vectorization.
-            metadatas:
-                An optional list of metadata. If provided it has to be of the same
-                length as a list of texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            location:
-                If `:memory:` - use in-memory Qdrant instance.
-                If `str` - use it as a `url` parameter.
-                If `None` - fallback to relying on `host` and `port` parameters.
-            url: either host or str of "Optional[scheme], host, Optional[port],
-                Optional[prefix]". Default: `None`
-            port: Port of the REST API interface. Default: 6333
-            grpc_port: Port of the gRPC interface. Default: 6334
-            prefer_grpc:
-                If true - use gRPC interface whenever possible in custom methods.
-                Default: False
-            https: If true - use HTTPS(SSL) protocol. Default: None
-            api_key: API key for authentication in Qdrant Cloud. Default: None
-            prefix:
-                If not None - add prefix to the REST URL path.
-                Example: service/v1 will result in
-                    http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
-                Default: None
-            timeout:
-                Timeout for REST and gRPC API requests.
-                Default: 5.0 seconds for REST and unlimited for gRPC
-            host:
-                Host name of Qdrant service. If url and host are None, set to
-                'localhost'. Default: None
-            path:
-                Path in which the vectors will be stored while using local mode.
-                Default: None
-            collection_name:
-                Name of the Qdrant collection to be used. If not provided,
-                it will be created randomly. Default: None
-            distance_func:
-                Distance function. One of: "Cosine" / "Euclid" / "Dot".
-                Default: "Cosine"
-            content_payload_key:
-                A payload key used to store the content of the document.
-                Default: "page_content"
-            metadata_payload_key:
-                A payload key used to store the metadata of the document.
-                Default: "metadata"
-            vector_name:
-                Name of the vector to be used internally in Qdrant.
-                Default: None
-            batch_size:
-                How many vectors to upload per request.
-                Default: 64
-            shard_number: Number of shards in collection. Default is 1, minimum is 1.
-            replication_factor:
-                Replication factor for collection. Default is 1, minimum is 1.
-                Defines how many copies of each shard will be created.
-                Has effect only in distributed mode.
-            write_consistency_factor:
-                Write consistency factor for collection. Default is 1, minimum is 1.
-                Defines how many replicas should apply the operation for us to consider
-                it successful. Increasing this number will make the collection more
-                resilient to inconsistencies, but will also make it fail if not enough
-                replicas are available.
-                Does not have any performance impact.
-                Has effect only in distributed mode.
-            on_disk_payload:
-                If true - point's payload will not be stored in memory.
-                It will be read from the disk every time it is requested.
-                This setting saves RAM by (slightly) increasing the response time.
-                Note: those payload values that are involved in filtering and are
-                indexed - remain in RAM.
-            hnsw_config: Params for HNSW index
-            optimizers_config: Params for optimizer
-            wal_config: Params for Write-Ahead-Log
-            quantization_config:
-                Params for quantization, if None - quantization will be disabled
-            init_from:
-                Use data stored in another collection to initialize this collection
-            force_recreate:
-                Force recreating the collection
-            **kwargs:
-                Additional arguments passed directly into REST client initialization
-
-        This is a user-friendly interface that:
-        1. Creates embeddings, one for each text
-        2. Initializes the Qdrant database as an in-memory docstore by default
-           (and overridable to a remote docstore)
-        3. Adds the text embeddings to the Qdrant database
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain import Qdrant
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
-        """
-        qdrant = cls._construct_instance(
-            texts,
-            embedding,
-            metadatas,
-            ids,
-            location,
-            url,
-            port,
-            grpc_port,
-            prefer_grpc,
-            https,
-            api_key,
-            prefix,
-            timeout,
-            host,
-            path,
-            collection_name,
-            distance_func,
-            content_payload_key,
-            metadata_payload_key,
-            vector_name,
-            shard_number,
-            replication_factor,
-            write_consistency_factor,
-            on_disk_payload,
-            hnsw_config,
-            optimizers_config,
-            wal_config,
-            quantization_config,
-            init_from,
-            force_recreate,
-            **kwargs,
-        )
-        await qdrant.aadd_texts(texts, metadatas, ids, batch_size)
-        return qdrant
-
-    @classmethod
-    def _construct_instance(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        vector_name: Optional[str] = VECTOR_NAME,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        try:
-            import qdrant_client
-        except ImportError:
-            raise ValueError(
-                "Could not import qdrant-client python package. "
-                "Please install it with `pip install qdrant-client`."
-            )
-        from qdrant_client.http import models as rest
-
-        # Just do a single quick embedding to get vector size
-        partial_embeddings = embedding.embed_documents(texts[:1])
-        vector_size = len(partial_embeddings[0])
-        collection_name = collection_name or uuid.uuid4().hex
-        distance_func = distance_func.upper()
-        is_new_collection = False
-        client = qdrant_client.QdrantClient(
-            location=location,
-            url=url,
-            port=port,
-            grpc_port=grpc_port,
-            prefer_grpc=prefer_grpc,
-            https=https,
-            api_key=api_key,
-            prefix=prefix,
-            timeout=timeout,
-            host=host,
-            path=path,
-            **kwargs,
-        )
-        all_collection_name = []
-        collections_response = client.get_collections()
-        collection_list = collections_response.collections
-        for collection in collection_list:
-            all_collection_name.append(collection.name)
-        if collection_name not in all_collection_name:
-            vectors_config = rest.VectorParams(
-                size=vector_size,
-                distance=rest.Distance[distance_func],
-            )
-
-            # If vector name was provided, we're going to use the named vectors feature
-            # with just a single vector.
-            if vector_name is not None:
-                vectors_config = {  # type: ignore[assignment]
-                    vector_name: vectors_config,
-                }
-
-            client.recreate_collection(
-                collection_name=collection_name,
-                vectors_config=vectors_config,
-                shard_number=shard_number,
-                replication_factor=replication_factor,
-                write_consistency_factor=write_consistency_factor,
-                on_disk_payload=on_disk_payload,
-                hnsw_config=hnsw_config,
-                optimizers_config=optimizers_config,
-                wal_config=wal_config,
-                quantization_config=quantization_config,
-                init_from=init_from,
-                timeout=int(timeout),  # type: ignore[arg-type]
-            )
-            is_new_collection = True
-        if force_recreate:
-            raise ValueError
-
-        # Get the vector configuration of the existing collection and vector, if it
-        # was specified. If the old configuration does not match the current one,
-        # an exception is being thrown.
-        collection_info = client.get_collection(collection_name=collection_name)
-        current_vector_config = collection_info.config.params.vectors
-        if isinstance(current_vector_config, dict) and vector_name is not None:
-            if vector_name not in current_vector_config:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} does not "
-                    f"contain vector named {vector_name}. Did you mean one of the "
-                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            current_vector_config = current_vector_config.get(
-                vector_name
-            )  # type: ignore[assignment]
-        elif isinstance(current_vector_config, dict) and vector_name is None:
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} uses named vectors. "
-                f"If you want to reuse it, please set `vector_name` to any of the "
-                f"existing named vectors: "
-                f"{', '.join(current_vector_config.keys())}. "  # noqa
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-        elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-        ):
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} doesn't use named "
-                f"vectors. If you want to reuse it, please set `vector_name` to "
-                f"`None`. If you want to recreate the collection, set "
-                f"`force_recreate` parameter to `True`."
-            )
-
-        # Check if the vector configuration has the same dimensionality.
-        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for vectors with "
-                f"{current_vector_config.size} "  # type: ignore[union-attr]
-                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-
-        current_distance_func = (
-            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-        )
-        if current_distance_func != distance_func:
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for "
-                f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                f"similarity. Please set `distance_func` parameter to "
-                f"`{distance_func}` if you want to reuse it. If you want to "
-                f"recreate the collection, set `force_recreate` parameter to "
-                f"`True`."
-            )
-        qdrant = cls(
-            client=client,
-            collection_name=collection_name,
-            embeddings=embedding,
-            content_payload_key=content_payload_key,
-            metadata_payload_key=metadata_payload_key,
-            distance_strategy=distance_func,
-            vector_name=vector_name,
-            group_id=group_id,
-            group_payload_key=group_payload_key,
-            is_new_collection=is_new_collection
-        )
-        return qdrant
-
-    def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        """
-        The 'correct' relevance function
-        may differ depending on a few things, including:
-        - the distance / similarity metric used by the VectorStore
-        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
-        - embedding dimensionality
-        - etc.
-        """
-
-        if self.distance_strategy == "COSINE":
-            return self._cosine_relevance_score_fn
-        elif self.distance_strategy == "DOT":
-            return self._max_inner_product_relevance_score_fn
-        elif self.distance_strategy == "EUCLID":
-            return self._euclidean_relevance_score_fn
-        else:
-            raise ValueError(
-                "Unknown distance strategy, must be cosine, "
-                "max_inner_product, or euclidean"
-            )
-
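Editor's note: _select_relevance_score_fn above only dispatches on the distance metric; the normalizers themselves live in the langchain VectorStore base class. A hedged sketch of what such normalizers typically look like, assuming Qdrant returns cosine similarity in [-1, 1] and non-negative Euclidean distances; these are illustrative formulas, not the library's exact ones.

import math

def cosine_relevance(similarity: float) -> float:
    # map cosine similarity from [-1, 1] onto [0, 1]
    return (similarity + 1.0) / 2.0

def euclidean_relevance(distance: float) -> float:
    # smaller distance means more similar; squash onto (0, 1]
    return 1.0 / (1.0 + distance)

def dot_product_relevance(score: float) -> float:
    # unbounded inner product; a sigmoid keeps the result in (0, 1)
    return 1.0 / (1.0 + math.exp(-score))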
-    def _similarity_search_with_relevance_scores(
-        self,
-        query: str,
-        k: int = 4,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs and relevance scores in the range [0, 1].
-
-        0 is dissimilar, 1 is most similar.
-
-        Args:
-            query: input text
-            k: Number of Documents to return. Defaults to 4.
-            **kwargs: kwargs to be passed to similarity search. Should include:
-                score_threshold: Optional, a floating point value between 0 to 1 to
-                    filter the resulting set of retrieved docs
-
-        Returns:
-            List of Tuples of (doc, similarity_score)
-        """
-        return self.similarity_search_with_score(query, k, **kwargs)
-
-    @classmethod
-    def _build_payloads(
-        cls,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]],
-        content_payload_key: str,
-        metadata_payload_key: str,
-        group_id: str,
-        group_payload_key: str
-    ) -> list[dict]:
-        payloads = []
-        for i, text in enumerate(texts):
-            if text is None:
-                raise ValueError(
-                    "At least one of the texts is None. Please remove it before "
-                    "calling .from_texts or .add_texts on Qdrant instance."
-                )
-            metadata = metadatas[i] if metadatas is not None else None
-            payloads.append(
-                {
-                    content_payload_key: text,
-                    metadata_payload_key: metadata,
-                    group_payload_key: group_id
-                }
-            )
-
-        return payloads
-
-    @classmethod
-    def _document_from_scored_point(
-        cls,
-        scored_point: Any,
-        content_payload_key: str,
-        metadata_payload_key: str,
-    ) -> Document:
-        return Document(
-            page_content=scored_point.payload.get(content_payload_key),
-            metadata=scored_point.payload.get(metadata_payload_key) or {},
-        )
-
-    @classmethod
-    def _document_from_scored_point_grpc(
-        cls,
-        scored_point: Any,
-        content_payload_key: str,
-        metadata_payload_key: str,
-    ) -> Document:
-        from qdrant_client.conversions.conversion import grpc_to_payload
-
-        payload = grpc_to_payload(scored_point.payload)
-        return Document(
-            page_content=payload[content_payload_key],
-            metadata=payload.get(metadata_payload_key) or {},
-        )
-
-    def _build_condition(self, key: str, value: Any) -> list[rest.FieldCondition]:
-        from qdrant_client.http import models as rest
-
-        out = []
-
-        if isinstance(value, dict):
-            for _key, value in value.items():
-                out.extend(self._build_condition(f"{key}.{_key}", value))
-        elif isinstance(value, list):
-            for _value in value:
-                if isinstance(_value, dict):
-                    out.extend(self._build_condition(f"{key}[]", _value))
-                else:
-                    out.extend(self._build_condition(f"{key}", _value))
-        else:
-            out.append(
-                rest.FieldCondition(
-                    key=key,
-                    match=rest.MatchValue(value=value),
-                )
-            )
-
-        return out
-
-    def _qdrant_filter_from_dict(
-        self, filter: Optional[DictFilter]
-    ) -> Optional[rest.Filter]:
-        from qdrant_client.http import models as rest
-
-        if not filter:
-            return None
-
-        return rest.Filter(
-            must=[
-                condition
-                for key, value in filter.items()
-                for condition in self._build_condition(key, value)
-            ]
-        )
-
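Editor's note: to make the recursion in _build_condition concrete, this is roughly the filter _qdrant_filter_from_dict produces for a small metadata dict; the field names and values are placeholders for the example.

from qdrant_client.http import models as rest

# _qdrant_filter_from_dict({"document_id": "doc-1", "tags": ["a", "b"]})
# flattens into a single must-clause filter along these lines:
equivalent_filter = rest.Filter(
    must=[
        rest.FieldCondition(key="document_id", match=rest.MatchValue(value="doc-1")),
        rest.FieldCondition(key="tags", match=rest.MatchValue(value="a")),
        rest.FieldCondition(key="tags", match=rest.MatchValue(value="b")),
    ]
)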
-    def _embed_query(self, query: str) -> list[float]:
-        """Embed query text.
-
-        Used to provide backward compatibility with `embedding_function` argument.
-
-        Args:
-            query: Query text.
-
-        Returns:
-            List of floats representing the query embedding.
-        """
-        if self.embeddings is not None:
-            embedding = self.embeddings.embed_query(query)
-        else:
-            if self._embeddings_function is not None:
-                embedding = self._embeddings_function(query)
-            else:
-                raise ValueError("Neither of embeddings or embedding_function is set")
-        return embedding.tolist() if hasattr(embedding, "tolist") else embedding
-
-    def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
-        """Embed search texts.
-
-        Used to provide backward compatibility with `embedding_function` argument.
-
-        Args:
-            texts: Iterable of texts to embed.
-
-        Returns:
-            List of floats representing the texts embedding.
-        """
-        if self.embeddings is not None:
-            embeddings = self.embeddings.embed_documents(list(texts))
-            if hasattr(embeddings, "tolist"):
-                embeddings = embeddings.tolist()
-        elif self._embeddings_function is not None:
-            embeddings = []
-            for text in texts:
-                embedding = self._embeddings_function(text)
-                if hasattr(embedding, "tolist"):
-                    embedding = embedding.tolist()
-                embeddings.append(embedding)
-        else:
-            raise ValueError("Neither of embeddings or embedding_function is set")
-
-        return embeddings
-
-    def _generate_rest_batches(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        group_id: Optional[str] = None,
-    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
-        from qdrant_client.http import models as rest
-
-        texts_iterator = iter(texts)
-        metadatas_iterator = iter(metadatas or [])
-        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
-        while batch_texts := list(islice(texts_iterator, batch_size)):
-            # Take the corresponding metadata and id for each text in a batch
-            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
-            batch_ids = list(islice(ids_iterator, batch_size))
-
-            # Generate the embeddings for all the texts in a batch
-            batch_embeddings = self._embed_texts(batch_texts)
-
-            points = [
-                rest.PointStruct(
-                    id=point_id,
-                    vector=vector
-                    if self.vector_name is None
-                    else {self.vector_name: vector},
-                    payload=payload,
-                )
-                for point_id, vector, payload in zip(
-                    batch_ids,
-                    batch_embeddings,
-                    self._build_payloads(
-                        batch_texts,
-                        batch_metadatas,
-                        self.content_payload_key,
-                        self.metadata_payload_key,
-                        self.group_id,
-                        self.group_payload_key
-                    ),
-                )
-            ]
-
-            yield batch_ids, points
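Editor's note: _generate_rest_batches walks three parallel iterators in lock-step with itertools.islice before embedding each batch. The batching idiom in isolation looks like this; a simplified sketch, not the method itself.

from itertools import islice
from uuid import uuid4

def batched(texts, ids=None, batch_size=64):
    texts_it = iter(texts)
    ids_it = iter(ids or [uuid4().hex for _ in iter(texts)])
    # islice consumes batch_size items per pass, so the iterators stay aligned
    while batch_texts := list(islice(texts_it, batch_size)):
        yield list(islice(ids_it, batch_size)), batch_texts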

+ 0 - 506
api/core/vector_store/vector/weaviate.py

@@ -1,506 +0,0 @@
-"""Wrapper around weaviate vector database."""
-from __future__ import annotations
-
-import datetime
-from collections.abc import Callable, Iterable
-from typing import Any, Optional
-from uuid import uuid4
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.utils import get_from_dict_or_env
-from langchain.vectorstores.base import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-
-
-def _default_schema(index_name: str) -> dict:
-    return {
-        "class": index_name,
-        "properties": [
-            {
-                "name": "text",
-                "dataType": ["text"],
-            }
-        ],
-    }
-
-
-def _create_weaviate_client(**kwargs: Any) -> Any:
-    client = kwargs.get("client")
-    if client is not None:
-        return client
-
-    weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
-
-    try:
-        # the weaviate api key param should not be mandatory
-        weaviate_api_key = get_from_dict_or_env(
-            kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
-        )
-    except ValueError:
-        weaviate_api_key = None
-
-    try:
-        import weaviate
-    except ImportError:
-        raise ValueError(
-            "Could not import weaviate python package. "
-            "Please install it with `pip install weaviate-client`"
-        )
-
-    auth = (
-        weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
-        if weaviate_api_key is not None
-        else None
-    )
-    client = weaviate.Client(weaviate_url, auth_client_secret=auth)
-
-    return client
-
-
-def _default_score_normalizer(val: float) -> float:
-    return 1 - val
-
-
-def _json_serializable(value: Any) -> Any:
-    if isinstance(value, datetime.datetime):
-        return value.isoformat()
-    return value
-
-
-class Weaviate(VectorStore):
-    """Wrapper around Weaviate vector database.
-
-    To use, you should have the ``weaviate-client`` python package installed.
-
-    Example:
-        .. code-block:: python
-
-            import weaviate
-            from langchain.vectorstores import Weaviate
-            client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
-            weaviate = Weaviate(client, index_name, text_key)
-
-    """
-
-    def __init__(
-        self,
-        client: Any,
-        index_name: str,
-        text_key: str,
-        embedding: Optional[Embeddings] = None,
-        attributes: Optional[list[str]] = None,
-        relevance_score_fn: Optional[
-            Callable[[float], float]
-        ] = _default_score_normalizer,
-        by_text: bool = True,
-    ):
-        """Initialize with Weaviate client."""
-        try:
-            import weaviate
-        except ImportError:
-            raise ValueError(
-                "Could not import weaviate python package. "
-                "Please install it with `pip install weaviate-client`."
-            )
-        if not isinstance(client, weaviate.Client):
-            raise ValueError(
-                f"client should be an instance of weaviate.Client, got {type(client)}"
-            )
-        self._client = client
-        self._index_name = index_name
-        self._embedding = embedding
-        self._text_key = text_key
-        self._query_attrs = [self._text_key]
-        self.relevance_score_fn = relevance_score_fn
-        self._by_text = by_text
-        if attributes is not None:
-            self._query_attrs.extend(attributes)
-
-    @property
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embedding
-
-    def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        return (
-            self.relevance_score_fn
-            if self.relevance_score_fn
-            else _default_score_normalizer
-        )
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Upload texts with metadata (properties) to Weaviate."""
-        from weaviate.util import get_valid_uuid
-
-        ids = []
-        embeddings: Optional[list[list[float]]] = None
-        if self._embedding:
-            if not isinstance(texts, list):
-                texts = list(texts)
-            embeddings = self._embedding.embed_documents(texts)
-
-        with self._client.batch as batch:
-            for i, text in enumerate(texts):
-                data_properties = {self._text_key: text}
-                if metadatas is not None:
-                    for key, val in metadatas[i].items():
-                        data_properties[key] = _json_serializable(val)
-
-                # Allow for ids (consistent w/ other methods)
-                # Or uuids (backwards compatible w/ existing arg)
-                # If the UUID of one of the objects already exists
-                # then the existing object will be replaced by the new object.
-                _id = get_valid_uuid(uuid4())
-                if "uuids" in kwargs:
-                    _id = kwargs["uuids"][i]
-                elif "ids" in kwargs:
-                    _id = kwargs["ids"][i]
-
-                batch.add_data_object(
-                    data_object=data_properties,
-                    class_name=self._index_name,
-                    uuid=_id,
-                    vector=embeddings[i] if embeddings else None,
-                )
-                ids.append(_id)
-        return ids
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        if self._by_text:
-            return self.similarity_search_by_text(query, k, **kwargs)
-        else:
-            if self._embedding is None:
-                raise ValueError(
-                    "_embedding cannot be None for similarity_search when "
-                    "_by_text=False"
-                )
-            embedding = self._embedding.embed_query(query)
-            return self.similarity_search_by_vector(embedding, k, **kwargs)
-
-    def similarity_search_by_text(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        result = query_obj.with_near_text(content).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
-    def similarity_search_by_bm25(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs using BM25F.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        properties = [self._text_key]
-        result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
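Editor's note: a hedged usage sketch for similarity_search_by_bm25 above, using the wrapper class defined in this file; the local URL and index name are placeholders, and no embedding model is needed because BM25F is keyword-based retrieval.

import weaviate

client = weaviate.Client("http://localhost:8080")  # assumed local instance
store = Weaviate(client, index_name="Document_example", text_key="text")
for doc in store.similarity_search_by_bm25("vector databases", k=4):
    print(doc.page_content[:80])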
-    def similarity_search_by_vector(
-        self, embedding: list[float], k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Look up similar documents by embedding vector in Weaviate."""
-        vector = {"vector": embedding}
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        result = query_obj.with_near_vector(vector).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        if self._embedding is not None:
-            embedding = self._embedding.embed_query(query)
-        else:
-            raise ValueError(
-                "max_marginal_relevance_search requires a suitable Embeddings object"
-            )
-
-        return self.max_marginal_relevance_search_by_vector(
-            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        vector = {"vector": embedding}
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        results = (
-            query_obj.with_additional("vector")
-            .with_near_vector(vector)
-            .with_limit(fetch_k)
-            .do()
-        )
-
-        payload = results["data"]["Get"][self._index_name]
-        embeddings = [result["_additional"]["vector"] for result in payload]
-        mmr_selected = maximal_marginal_relevance(
-            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
-        )
-
-        docs = []
-        for idx in mmr_selected:
-            text = payload[idx].pop(self._text_key)
-            payload[idx].pop("_additional")
-            meta = payload[idx]
-            docs.append(Document(page_content=text, metadata=meta))
-        return docs
-
-    def similarity_search_with_score(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[tuple[Document, float]]:
-        """
-        Return list of documents most similar to the query
-        text and cosine distance in float for each.
-        Lower score represents more similarity.
-        """
-        if self._embedding is None:
-            raise ValueError(
-                "_embedding cannot be None for similarity_search_with_score"
-            )
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-
-        embedded_query = self._embedding.embed_query(query)
-        if not self._by_text:
-            vector = {"vector": embedded_query}
-            result = (
-                query_obj.with_near_vector(vector)
-                .with_limit(k)
-                .with_additional(["vector", "distance"])
-                .do()
-            )
-        else:
-            result = (
-                query_obj.with_near_text(content)
-                .with_limit(k)
-                .with_additional(["vector", "distance"])
-                .do()
-            )
-
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-
-        docs_and_scores = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            score = res["_additional"]["distance"]
-            docs_and_scores.append((Document(page_content=text, metadata=res), score))
-        return docs_and_scores
-
-    @classmethod
-    def from_texts(
-        cls: type[Weaviate],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        **kwargs: Any,
-    ) -> Weaviate:
-        """Construct Weaviate wrapper from raw documents.
-
-        This is a user-friendly interface that:
-            1. Embeds documents.
-            2. Creates a new index for the embeddings in the Weaviate instance.
-            3. Adds the documents to the newly created Weaviate index.
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain.vectorstores.weaviate import Weaviate
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                weaviate = Weaviate.from_texts(
-                    texts,
-                    embeddings,
-                    weaviate_url="http://localhost:8080"
-                )
-        """
-
-        client = _create_weaviate_client(**kwargs)
-
-        from weaviate.util import get_valid_uuid
-
-        index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
-        embeddings = embedding.embed_documents(texts) if embedding else None
-        text_key = "text"
-        schema = _default_schema(index_name)
-        attributes = list(metadatas[0].keys()) if metadatas else None
-
-        # check whether the index already exists
-        if not client.schema.contains(schema):
-            client.schema.create_class(schema)
-
-        with client.batch as batch:
-            for i, text in enumerate(texts):
-                data_properties = {
-                    text_key: text,
-                }
-                if metadatas is not None:
-                    for key in metadatas[i].keys():
-                        data_properties[key] = metadatas[i][key]
-
-                # If the UUID of one of the objects already exists
-                # then the existing object will be replaced by the new object.
-                if "uuids" in kwargs:
-                    _id = kwargs["uuids"][i]
-                else:
-                    _id = get_valid_uuid(uuid4())
-
-                # if an embedding strategy is not provided, we let
-                # weaviate create the embedding. Note that this will only
-                # work if weaviate has been installed with a vectorizer module
-                # like text2vec-contextionary for example
-                params = {
-                    "uuid": _id,
-                    "data_object": data_properties,
-                    "class_name": index_name,
-                }
-                if embeddings is not None:
-                    params["vector"] = embeddings[i]
-
-                batch.add_data_object(**params)
-
-            batch.flush()
-
-        relevance_score_fn = kwargs.get("relevance_score_fn")
-        by_text: bool = kwargs.get("by_text", False)
-
-        return cls(
-            client,
-            index_name,
-            text_key,
-            embedding=embedding,
-            attributes=attributes,
-            relevance_score_fn=relevance_score_fn,
-            by_text=by_text,
-        )
-
-    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
-        """Delete by vector IDs.
-
-        Args:
-            ids: List of ids to delete.
-        """
-
-        if ids is None:
-            raise ValueError("No ids provided to delete.")
-
-        # TODO: Check if this can be done in bulk
-        for id in ids:
-            self._client.data_object.delete(uuid=id)

+ 0 - 38
api/core/vector_store/weaviate_vector_store.py

@@ -1,38 +0,0 @@
-from core.vector_store.vector.weaviate import Weaviate
-
-
-class WeaviateVectorStore(Weaviate):
-    def del_texts(self, where_filter: dict):
-        if not where_filter:
-            raise ValueError('where_filter must not be empty')
-
-        self._client.batch.delete_objects(
-            class_name=self._index_name,
-            where=where_filter,
-            output='minimal'
-        )
-
-    def del_text(self, uuid: str) -> None:
-        self._client.data_object.delete(
-            uuid,
-            class_name=self._index_name
-        )
-
-    def text_exists(self, uuid: str) -> bool:
-        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
-            "path": ["doc_id"],
-            "operator": "Equal",
-            "valueText": uuid,
-        }).with_limit(1).do()
-
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-
-        entries = result["data"]["Get"][self._index_name]
-        if len(entries) == 0:
-            return False
-
-        return True
-
-    def delete(self):
-        self._client.schema.delete_class(self._index_name)
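Editor's note: del_texts above expects a raw Weaviate where-filter whose shape mirrors the one built in text_exists. A hedged example; the field name, id value, local URL and index name are placeholders.

import weaviate

client = weaviate.Client("http://localhost:8080")  # assumed local instance
store = WeaviateVectorStore(client, "Document_example", "text")
store.del_texts({
    "path": ["document_id"],            # metadata field to match on (placeholder)
    "operator": "Equal",
    "valueText": "example-document-id",
})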

+ 1 - 1
api/events/event_handlers/clean_when_dataset_deleted.py

@@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task
 def handle(sender, **kwargs):
     dataset = sender
     clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
-                             dataset.index_struct, dataset.collection_binding_id)
+                             dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
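Editor's note: the call above now passes dataset.doc_form as a trailing argument, which implies the Celery task signature grows by one parameter. A hedged sketch of the receiving side; the decorator options and parameter names are assumptions, not taken from this diff.

from celery import shared_task

@shared_task(queue='dataset')  # queue name assumed
def clean_dataset_task(dataset_id, tenant_id, indexing_technique,
                       index_struct, collection_binding_id, doc_form):
    # doc_form is used to pick the index processor that cleans the dataset's
    # vector and keyword data
    ...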

+ 2 - 1
api/events/event_handlers/clean_when_document_deleted.py

@@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task
 def handle(sender, **kwargs):
     document_id = sender
     dataset_id = kwargs.get('dataset_id')
-    clean_document_task.delay(document_id, dataset_id)
+    doc_form = kwargs.get('doc_form')
+    clean_document_task.delay(document_id, dataset_id, doc_form)
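Editor's note: these handlers hang off blinker signals, so the doc_form keyword has to be supplied by the sender. A minimal sketch of the emit/handle pairing; the signal name and the 'text_model' value are placeholders.

from blinker import signal

document_was_deleted = signal('document-was-deleted')  # name assumed

@document_was_deleted.connect
def handle(sender, **kwargs):
    dataset_id = kwargs.get('dataset_id')
    doc_form = kwargs.get('doc_form')
    print(sender, dataset_id, doc_form)

# emitting side, mirroring the dataset_service.py change later in this diff
document_was_deleted.send('document-id-1', dataset_id='dataset-id-1', doc_form='text_model')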

+ 8 - 0
api/models/dataset.py

@@ -94,6 +94,14 @@ class Dataset(db.Model):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
             .filter(Document.dataset_id == self.id).scalar()

+    @property
+    def doc_form(self):
+        document = db.session.query(Document).filter(
+            Document.dataset_id == self.id).first()
+        if document:
+            return document.doc_form
+        return None
+
     @property
     def retrieval_model_dict(self):
         default_retrieval_model = {
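Editor's note: doc_form is derived from the dataset's first document, so it is None for an empty dataset. A hedged usage sketch, given a Dataset ORM instance named dataset; the factory import and method names come from the clean_unused_datasets_task hunk below, while the 'text_model' fallback is an assumption.

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

doc_form = dataset.doc_form or 'text_model'  # fallback value assumed
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)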

+ 4 - 13
api/schedule/clean_unused_datasets_task.py

@@ -6,7 +6,7 @@ from flask import current_app
 from werkzeug.exceptions import NotFound

 import app
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, Document

@@ -41,18 +41,9 @@ def clean_unused_datasets_task():
                 if not documents or len(documents) == 0:
                     try:
                         # remove index
-                        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-                        kw_index = IndexBuilder.get_index(dataset, 'economy')
-                        # delete from vector index
-                        if vector_index:
-                            if dataset.collection_binding_id:
-                                vector_index.delete_by_group_id(dataset.id)
-                            else:
-                                if dataset.collection_binding_id:
-                                    vector_index.delete_by_group_id(dataset.id)
-                                else:
-                                    vector_index.delete()
-                        kw_index.delete()
+                        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
+                        index_processor.clean(dataset, None)
+
                         # update document
                         update_params = {
                             Document.enabled: False
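
The scheduled cleanup no longer fetches separate vector and keyword indexes; a single index processor presumably replaces both deletions. A hedged sketch of the new pattern as used above; remove_dataset_indexes is a hypothetical wrapper, and the meaning of the second argument to clean() is not shown here beyond the scheduler passing None:

    from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
    from models.dataset import Dataset

    def remove_dataset_indexes(dataset: Dataset):
        # Hypothetical helper mirroring the hunk above.
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        index_processor.clean(dataset, None)  # None appears to mean "clean everything" (assumption)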

+ 21 - 14
api/services/dataset_service.py

@@ -11,10 +11,11 @@ from flask_login import current_user
 from sqlalchemy import func
 
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.index.index import IndexBuilder
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.models.document import Document as RAGDocument
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -402,7 +403,7 @@ class DocumentService:
     @staticmethod
     def delete_document(document):
         # trigger document_was_deleted signal
-        document_was_deleted.send(document.id, dataset_id=document.dataset_id)
+        document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form)
 
         db.session.delete(document)
         db.session.commit()
@@ -1060,7 +1061,7 @@ class SegmentService:
 
         # save vector index
         try:
-            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
+            VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             segment_document.enabled = False
@@ -1087,6 +1088,7 @@ class SegmentService:
         ).scalar()
         pre_segment_data_list = []
         segment_data_list = []
+        keywords_list = []
         for segment_item in segments:
             content = segment_item['content']
             doc_id = str(uuid.uuid4())
@@ -1119,15 +1121,13 @@ class SegmentService:
                 segment_document.answer = segment_item['answer']
             db.session.add(segment_document)
             segment_data_list.append(segment_document)
-            pre_segment_data = {
-                'segment': segment_document,
-                'keywords': segment_item['keywords']
-            }
-            pre_segment_data_list.append(pre_segment_data)
+
+            pre_segment_data_list.append(segment_document)
+            keywords_list.append(segment_item['keywords'])
 
         try:
             # save vector index
-            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             for segment_document in segment_data_list:
@@ -1157,11 +1157,18 @@ class SegmentService:
                 db.session.commit()
                 # update segment index task
                 if args['keywords']:
-                    kw_index = IndexBuilder.get_index(dataset, 'economy')
-                    # delete from keyword index
-                    kw_index.delete_by_ids([segment.index_node_id])
-                    # save keyword index
-                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
+                    keyword = Keyword(dataset)
+                    keyword.delete_by_ids([segment.index_node_id])
+                    document = RAGDocument(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
+                    )
+                    keyword.add_texts([document], keywords_list=[args['keywords']])
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0
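
Segment keyword updates now go through the Keyword facade instead of an IndexBuilder keyword index: delete the old node id, rebuild a RAG Document from the segment, and re-add it with the caller-supplied keywords. A condensed sketch of that flow; refresh_segment_keywords is a hypothetical helper, while the calls and metadata keys mirror the hunk above:

    from core.rag.datasource.keyword.keyword_factory import Keyword
    from core.rag.models.document import Document as RAGDocument

    def refresh_segment_keywords(dataset, segment, keywords):
        # Delete the stale keyword entry, then re-add the segment with the new keywords.
        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])
        document = RAGDocument(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )
        keyword.add_texts([document], keywords_list=[keywords])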

+ 5 - 4
api/services/file_service.py

@@ -9,8 +9,8 @@ from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
-from core.data_loader.file_extractor import FileExtractor
 from core.file.upload_file_parser import UploadFileParser
+from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
@@ -32,7 +32,8 @@ class FileService:
     def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
         extension = file.filename.split('.')[-1]
         etl_type = current_app.config['ETL_TYPE']
-        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
+            else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
         elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
@@ -136,7 +137,7 @@ class FileService:
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
 
-        text = FileExtractor.load(upload_file, return_text=True)
+        text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
 
         return text
@@ -164,7 +165,7 @@ class FileService:
         return generator, upload_file.mime_type
 
     @staticmethod
-    def get_public_image_preview(file_id: str) -> str:
+    def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()
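
File previews now run through ExtractProcessor instead of the removed FileExtractor. A hedged sketch of the preview path, assuming upload_file is an UploadFile row; the helper name and character limit are illustrative, the service itself uses PREVIEW_WORDS_LIMIT:

    from core.rag.extractor.extract_processor import ExtractProcessor

    def preview_text(upload_file, limit: int = 3000):
        # return_text=True yields plain text, as in the hunk above; without it the
        # extractor presumably returns document objects instead.
        text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
        return text[0:limit] if text else ''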

+ 13 - 62
api/services/hit_testing_service.py

@@ -1,21 +1,18 @@
 import logging
-import threading
 import time
 
 import numpy as np
-from flask import current_app
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
 from sklearn.manifold import TSNE
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rerank.rerank import RerankRunner
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
-from services.retrieval_service import RetrievalService
 
 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -28,6 +25,7 @@ default_retrieval_model = {
     'score_threshold_enabled': False
 }
 
+
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
@@ -57,61 +55,15 @@
 
         embeddings = CacheEmbedding(embedding_model)
 
-        all_documents = []
-        threads = []
-
-        # retrieval_model source with semantic
-        if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
-            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'top_k': retrieval_model['top_k'],
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings
-            })
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
-            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings,
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'top_k': retrieval_model['top_k'],
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents
-            })
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
-
-        for thread in threads:
-            thread.join()
-
-        if retrieval_model['search_method'] == 'hybrid_search':
-            model_manager = ModelManager()
-            rerank_model_instance = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                provider=retrieval_model['reranking_model']['reranking_provider_name'],
-                model_type=ModelType.RERANK,
-                model=retrieval_model['reranking_model']['reranking_model_name']
-            )
-
-            rerank_runner = RerankRunner(rerank_model_instance)
-            all_documents = rerank_runner.run(
-                query=query,
-                documents=all_documents,
-                score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                top_n=retrieval_model['top_k'],
-                user=f"account-{account.id}"
-            )
+        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                  dataset_id=dataset.id,
+                                                  query=query,
+                                                  top_k=retrieval_model['top_k'],
+                                                  score_threshold=retrieval_model['score_threshold']
+                                                  if retrieval_model['score_threshold_enabled'] else None,
+                                                  reranking_model=retrieval_model['reranking_model']
+                                                  if retrieval_model['reranking_enable'] else None
+                                                  )
 
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
@@ -203,4 +155,3 @@
 
         if not query or len(query) > 250:
             raise ValueError('Query is required and cannot exceed 250 characters')
-
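
Hit testing no longer spawns its own embedding and full-text threads or runs the reranker inline: the old services.retrieval_service module (deleted below) is superseded by core.rag.datasource.retrieval_service, and a single RetrievalService.retrieve call takes over. A hedged sketch of the consolidated call, assuming dataset and query are bound as in the service above; note the retrival_method spelling used by the new API, and that coverage of full-text and hybrid search is inferred from the code this call replaces:

    from core.rag.datasource.retrieval_service import RetrievalService

    all_documents = RetrievalService.retrieve(
        retrival_method='semantic_search',  # 'full_text_search' / 'hybrid_search' presumably accepted too
        dataset_id=dataset.id,
        query=query,
        top_k=2,                            # default top_k from default_retrieval_model
        score_threshold=None,               # set only when score_threshold_enabled is true
        reranking_model=None,               # set only when reranking_enable is true
    )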

+ 0 - 119
api/services/retrieval_service.py

@@ -1,119 +0,0 @@
-from typing import Optional
-
-from flask import Flask, current_app
-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
-from extensions.ext_database import db
-from models.dataset import Dataset
-
-default_retrieval_model = {
-    'search_method': 'semantic_search',
-    'reranking_enable': False,
-    'reranking_model': {
-        'reranking_provider_name': '',
-        'reranking_model_name': ''
-    },
-    'top_k': 2,
-    'score_threshold_enabled': False
-}
-
-
-class RetrievalService:
-
-    @classmethod
-    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                         all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search(
-                query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': top_k,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
-                }
-            )
-
-            if documents:
-                if reranking_model and search_method == 'semantic_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
-
-    @classmethod
-    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                               all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search_by_full_text_index(
-                query,
-                search_type='similarity_score_threshold',
-                top_k=top_k
-            )
-            if documents:
-                if reranking_model and search_method == 'full_text_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)

+ 31 - 54
api/services/vector_service.py

@@ -1,44 +1,18 @@
-
 from typing import Optional
 
-from langchain.schema import Document
-
-from core.index.index import IndexBuilder
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from models.dataset import Dataset, DocumentSegment
 
 
 class VectorService:
 
     @classmethod
-    def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
-        document = Document(
-            page_content=segment.content,
-            metadata={
-                "doc_id": segment.index_node_id,
-                "doc_hash": segment.index_node_hash,
-                "document_id": segment.document_id,
-                "dataset_id": segment.dataset_id,
-            }
-        )
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            if keywords and len(keywords) > 0:
-                index.create_segment_keywords(segment.index_node_id, keywords)
-            else:
-                index.add_texts([document])
-
-    @classmethod
-    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
+    def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
+                               segments: list[DocumentSegment], dataset: Dataset):
         documents = []
-        for pre_segment_data in pre_segment_data_list:
-            segment = pre_segment_data['segment']
+        for segment in segments:
             document = Document(
                 page_content=segment.content,
                 metadata={
@@ -49,30 +23,26 @@ class VectorService:
                 }
             )
             documents.append(document)
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents, duplicate_check=True)
+        if dataset.indexing_technique == 'high_quality':
+            # save vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.add_texts(documents, duplicate_check=True)
 
         # save keyword index
-        keyword_index = IndexBuilder.get_index(dataset, 'economy')
-        if keyword_index:
-            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
+        keyword = Keyword(dataset)
+
+        if keywords_list and len(keywords_list) > 0:
+            keyword.add_texts(documents, keyword_list=keywords_list)
+        else:
+            keyword.add_texts(documents)
 
     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
         # update segment index task
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([segment.index_node_id])
 
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
-
-        # add new index
+        # format new index
         document = Document(
             page_content=segment.content,
             metadata={
@@ -82,13 +52,20 @@ class VectorService:
                 "dataset_id": segment.dataset_id,
                 "dataset_id": segment.dataset_id,
             }
             }
         )
         )
+        if dataset.indexing_technique == 'high_quality':
+            # update vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.delete_by_ids([segment.index_node_id])
+            vector.add_texts([document], duplicate_check=True)
 
 
-        # save vector index
-        if vector_index:
-            vector_index.add_texts([document], duplicate_check=True)
+        # update keyword index
+        keyword = Keyword(dataset)
+        keyword.delete_by_ids([segment.index_node_id])
 
 
         # save keyword index
         # save keyword index
         if keywords and len(keywords) > 0:
         if keywords and len(keywords) > 0:
-            kw_index.create_segment_keywords(segment.index_node_id, keywords)
+            keyword.add_texts([document], keywords_list=[keywords])
         else:
-            kw_index.add_texts([document])
+            keyword.add_texts([document])
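
VectorService now splits the work between the Vector factory (only for 'high_quality' datasets) and the Keyword facade, and the old single-segment entry point is folded into create_segments_vector. A hedged sketch from the caller's side, mirroring the call sites in api/services/dataset_service.py above; the import path follows the existing services.* convention and the variables are assumed to be ORM objects:

    from services.vector_service import VectorService

    # Single segment: wrap the keywords and the segment in one-element lists.
    VectorService.create_segments_vector([keywords], [segment_document], dataset)

    # Batch: one keyword list per segment, in the same order as the segments.
    VectorService.create_segments_vector(keywords_list, segment_documents, dataset)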

+ 5 - 11
api/tasks/add_document_to_index_task.py

@@ -4,10 +4,10 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Document as DatasetDocument
@@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts(documents)
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.load(dataset, documents)
 
         end_at = time.perf_counter()
         logging.info(
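
Adding a document to the index now means building RAG Document objects and handing them to one index processor, rather than writing to the vector and keyword indexes separately. A hedged sketch of the load step, assuming documents were built from the dataset document's segments as in this task; the metadata keys mirror the ones used elsewhere in this changeset:

    from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
    from core.rag.models.document import Document

    documents = [
        Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )
        for segment in segments  # segments: the document's enabled DocumentSegment rows (assumed)
    ]

    index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
    index_processor.load(dataset, documents)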

Not all modified files are shown because too many files were changed