瀏覽代碼

Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 年之前
父節點
當前提交
6c4e6bf1d6
共有 100 個文件被更改,包括 3097 次插入5550 次删除
  1. 二進制
      api/celerybeat-schedule.db
  2. 2 1
      api/config.py
  3. 26 6
      api/controllers/console/datasets/data_source.py
  4. 41 33
      api/controllers/console/datasets/datasets.py
  5. 43 29
      api/controllers/console/datasets/datasets_document.py
  6. 0 107
      api/core/data_loader/file_extractor.py
  7. 0 34
      api/core/data_loader/loader/html.py
  8. 0 55
      api/core/data_loader/loader/pdf.py
  9. 1 1
      api/core/docstore/dataset_docstore.py
  10. 7 31
      api/core/features/annotation_reply.py
  11. 0 51
      api/core/index/index.py
  12. 0 305
      api/core/index/vector_index/base.py
  13. 0 165
      api/core/index/vector_index/milvus_vector_index.py
  14. 0 229
      api/core/index/vector_index/qdrant_vector_index.py
  15. 0 90
      api/core/index/vector_index/vector_index.py
  16. 0 179
      api/core/index/vector_index/weaviate_vector_index.py
  17. 134 252
      api/core/indexing_runner.py
  18. 0 0
      api/core/rag/__init__.py
  19. 38 0
      api/core/rag/cleaner/clean_processor.py
  20. 12 0
      api/core/rag/cleaner/cleaner_base.py
  21. 12 0
      api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
  22. 15 0
      api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
  23. 12 0
      api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
  24. 11 0
      api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
  25. 11 0
      api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
  26. 0 0
      api/core/rag/data_post_processor/__init__.py
  27. 49 0
      api/core/rag/data_post_processor/data_post_processor.py
  28. 19 0
      api/core/rag/data_post_processor/reorder.py
  29. 0 0
      api/core/rag/datasource/__init__.py
  30. 21 0
      api/core/rag/datasource/entity/embedding.py
  31. 0 0
      api/core/rag/datasource/keyword/__init__.py
  32. 0 0
      api/core/rag/datasource/keyword/jieba/__init__.py
  33. 18 90
      api/core/rag/datasource/keyword/jieba/jieba.py
  34. 1 1
      api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
  35. 0 0
      api/core/rag/datasource/keyword/jieba/stopwords.py
  36. 5 23
      api/core/rag/datasource/keyword/keyword_base.py
  37. 60 0
      api/core/rag/datasource/keyword/keyword_factory.py
  38. 165 0
      api/core/rag/datasource/retrieval_service.py
  39. 0 0
      api/core/rag/datasource/vdb/__init__.py
  40. 10 0
      api/core/rag/datasource/vdb/field.py
  41. 0 0
      api/core/rag/datasource/vdb/milvus/__init__.py
  42. 214 0
      api/core/rag/datasource/vdb/milvus/milvus_vector.py
  43. 0 0
      api/core/rag/datasource/vdb/qdrant/__init__.py
  44. 360 0
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  45. 62 0
      api/core/rag/datasource/vdb/vector_base.py
  46. 171 0
      api/core/rag/datasource/vdb/vector_factory.py
  47. 0 0
      api/core/rag/datasource/vdb/weaviate/__init__.py
  48. 235 0
      api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
  49. 166 0
      api/core/rag/extractor/blod/blod.py
  50. 71 0
      api/core/rag/extractor/csv_extractor.py
  51. 6 0
      api/core/rag/extractor/entity/datasource_type.py
  52. 36 0
      api/core/rag/extractor/entity/extract_setting.py
  53. 14 9
      api/core/rag/extractor/excel_extractor.py
  54. 139 0
      api/core/rag/extractor/extract_processor.py
  55. 12 0
      api/core/rag/extractor/extractor_base.py
  56. 46 0
      api/core/rag/extractor/helpers.py
  57. 24 20
      api/core/rag/extractor/html_extractor.py
  58. 14 26
      api/core/rag/extractor/markdown_extractor.py
  59. 23 37
      api/core/rag/extractor/notion_extractor.py
  60. 72 0
      api/core/rag/extractor/pdf_extractor.py
  61. 50 0
      api/core/rag/extractor/text_extractor.py
  62. 61 0
      api/core/rag/extractor/unstructured/unstructured_doc_extractor.py
  63. 5 4
      api/core/rag/extractor/unstructured/unstructured_eml_extractor.py
  64. 4 4
      api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py
  65. 4 4
      api/core/rag/extractor/unstructured/unstructured_msg_extractor.py
  66. 8 7
      api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py
  67. 9 7
      api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py
  68. 4 4
      api/core/rag/extractor/unstructured/unstructured_text_extractor.py
  69. 4 4
      api/core/rag/extractor/unstructured/unstructured_xml_extractor.py
  70. 62 0
      api/core/rag/extractor/word_extractor.py
  71. 0 0
      api/core/rag/index_processor/__init__.py
  72. 0 0
      api/core/rag/index_processor/constant/__init__.py
  73. 8 0
      api/core/rag/index_processor/constant/index_type.py
  74. 70 0
      api/core/rag/index_processor/index_processor_base.py
  75. 28 0
      api/core/rag/index_processor/index_processor_factory.py
  76. 0 0
      api/core/rag/index_processor/processor/__init__.py
  77. 92 0
      api/core/rag/index_processor/processor/paragraph_index_processor.py
  78. 161 0
      api/core/rag/index_processor/processor/qa_index_processor.py
  79. 0 0
      api/core/rag/models/__init__.py
  80. 16 0
      api/core/rag/models/document.py
  81. 5 5
      api/core/tool/web_reader_tool.py
  82. 16 71
      api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py
  83. 17 95
      api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
  84. 5 5
      api/core/tools/utils/web_reader_tool.py
  85. 0 56
      api/core/vector_store/milvus_vector_store.py
  86. 0 76
      api/core/vector_store/qdrant_vector_store.py
  87. 0 852
      api/core/vector_store/vector/milvus.py
  88. 0 1759
      api/core/vector_store/vector/qdrant.py
  89. 0 506
      api/core/vector_store/vector/weaviate.py
  90. 0 38
      api/core/vector_store/weaviate_vector_store.py
  91. 1 1
      api/events/event_handlers/clean_when_dataset_deleted.py
  92. 2 1
      api/events/event_handlers/clean_when_document_deleted.py
  93. 8 0
      api/models/dataset.py
  94. 4 13
      api/schedule/clean_unused_datasets_task.py
  95. 21 14
      api/services/dataset_service.py
  96. 5 4
      api/services/file_service.py
  97. 13 62
      api/services/hit_testing_service.py
  98. 0 119
      api/services/retrieval_service.py
  99. 31 54
      api/services/vector_service.py
  100. 5 11
      api/tasks/add_document_to_index_task.py

二進制
api/celerybeat-schedule.db


+ 2 - 1
api/config.py

@@ -56,6 +56,7 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
     'BATCH_UPLOAD_LIMIT': 20
 }
 
@@ -183,7 +184,7 @@ class Config:
         # Currently, only support: qdrant, milvus, zilliz, weaviate
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
-
+        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')

+ 26 - 6
api/controllers/console/datasets/data_source.py

@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
@@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource):
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
 
-        loader = NotionLoader(
-            notion_access_token=data_source_binding.access_token,
+        extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
-            notion_page_type=page_type
+            notion_page_type=page_type,
+            notion_access_token=data_source_binding.access_token
         )
 
-        text_docs = loader.load()
+        text_docs = extractor.extract()
         return {
             'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
@@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        notion_info_list = args['notion_info_list']
+        extract_settings = []
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            for page in notion_info['pages']:
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page['page_id'],
+                        "notion_page_type": page['type']
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                     args['process_rule'], args['doc_form'],
+                                                     args['doc_language'])
         return response, 200
 
 

+ 41 - 33
api/controllers/console/datasets/datasets.py

@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
-                    choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                    nullable=True,
-                    help='Invalid indexing technique.')
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                            nullable=True,
+                            help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True, 
+        parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")
 
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type']
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         else:
             raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+
         return response, 200
 
 
@@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-

+ 43 - 29
api/controllers/console/datasets/datasets_document.py

@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args
 
         document_id = req_data.get('document_id')
-        
+
         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 if not file:
                     raise NotFound('File not found.')
 
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file,
+                    document_model=document.doc_form
+                )
+
                 indexing_runner = IndexingRunner()
 
                 try:
-                    response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                      data_process_rule_dict, None,
-                                                                      'English', dataset_id)
+                    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                                 data_process_rule_dict, document.doc_form,
+                                                                 'English', dataset_id)
                 except LLMBadRequestError:
                     raise ProviderNotInitializeError(
                         "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 }
                 info_list.append(notion_info)
 
-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
 
-            if file_details is None:
-                raise NotFound("File not found.")
+                if file_detail is None:
+                    raise NotFound("File not found.")
 
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type']
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
 
+            else:
+                raise ValueError('Data source type not support')
             indexing_runner = IndexingRunner()
             try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    info_list,
-                                                                    data_process_rule_dict,
-                                                                    None, 'English', dataset_id)
+                response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                             data_process_rule_dict, document.doc_form,
+                                                             'English', dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
                     "in the Settings -> Model Provider.")
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')
         return response
 
 

+ 0 - 107
api/core/data_loader/file_extractor.py

@@ -1,107 +0,0 @@
-import tempfile
-from pathlib import Path
-from typing import Optional, Union
-
-import requests
-from flask import current_app
-from langchain.document_loaders import Docx2txtLoader, TextLoader
-from langchain.schema import Document
-
-from core.data_loader.loader.csv_loader import CSVLoader
-from core.data_loader.loader.excel import ExcelLoader
-from core.data_loader.loader.html import HTMLLoader
-from core.data_loader.loader.markdown import MarkdownLoader
-from core.data_loader.loader.pdf import PdfLoader
-from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
-from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
-from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
-from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
-from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
-from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
-from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
-USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-
-
-class FileExtractor:
-    @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, file_path)
-
-            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
-
-    @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
-        response = requests.get(url, headers={
-            "User-Agent": USER_AGENT
-        })
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(url).suffix
-            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, 'wb') as file:
-                file.write(response.content)
-
-            return cls.load_from_file(file_path, return_text)
-
-    @classmethod
-    def load_from_file(cls, file_path: str, return_text: bool = False,
-                       upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[list[Document], str]:
-        input_file = Path(file_path)
-        delimiter = '\n'
-        file_extension = input_file.suffix.lower()
-        etl_type = current_app.config['ETL_TYPE']
-        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
-        if etl_type == 'Unstructured':
-            if file_extension == '.xlsx':
-                loader = ExcelLoader(file_path)
-            elif file_extension == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif file_extension in ['.md', '.markdown']:
-                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
-                    else MarkdownLoader(file_path, autodetect_encoding=True)
-            elif file_extension in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif file_extension in ['.docx']:
-                loader = Docx2txtLoader(file_path)
-            elif file_extension == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            elif file_extension == '.msg':
-                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
-            elif file_extension == '.eml':
-                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
-            elif file_extension == '.ppt':
-                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
-            elif file_extension == '.pptx':
-                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
-            elif file_extension == '.xml':
-                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
-            else:
-                # txt
-                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
-                    else TextLoader(file_path, autodetect_encoding=True)
-        else:
-            if file_extension == '.xlsx':
-                loader = ExcelLoader(file_path)
-            elif file_extension == '.pdf':
-                loader = PdfLoader(file_path, upload_file=upload_file)
-            elif file_extension in ['.md', '.markdown']:
-                loader = MarkdownLoader(file_path, autodetect_encoding=True)
-            elif file_extension in ['.htm', '.html']:
-                loader = HTMLLoader(file_path)
-            elif file_extension in ['.docx']:
-                loader = Docx2txtLoader(file_path)
-            elif file_extension == '.csv':
-                loader = CSVLoader(file_path, autodetect_encoding=True)
-            else:
-                # txt
-                loader = TextLoader(file_path, autodetect_encoding=True)
-
-        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

+ 0 - 34
api/core/data_loader/loader/html.py

@@ -1,34 +0,0 @@
-import logging
-
-from bs4 import BeautifulSoup
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
-
-logger = logging.getLogger(__name__)
-
-
-class HTMLLoader(BaseLoader):
-    """Load html files.
-
-
-    Args:
-        file_path: Path to the file to load.
-    """
-
-    def __init__(
-        self,
-        file_path: str
-    ):
-        """Initialize with file path."""
-        self._file_path = file_path
-
-    def load(self) -> list[Document]:
-        return [Document(page_content=self._load_as_text())]
-
-    def _load_as_text(self) -> str:
-        with open(self._file_path, "rb") as fp:
-            soup = BeautifulSoup(fp, 'html.parser')
-            text = soup.get_text()
-            text = text.strip() if text else ''
-
-        return text

+ 0 - 55
api/core/data_loader/loader/pdf.py

@@ -1,55 +0,0 @@
-import logging
-from typing import Optional
-
-from langchain.document_loaders import PyPDFium2Loader
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
-
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-logger = logging.getLogger(__name__)
-
-
-class PdfLoader(BaseLoader):
-    """Load pdf files.
-
-
-    Args:
-        file_path: Path to the file to load.
-    """
-
-    def __init__(
-        self,
-        file_path: str,
-        upload_file: Optional[UploadFile] = None
-    ):
-        """Initialize with file path."""
-        self._file_path = file_path
-        self._upload_file = upload_file
-
-    def load(self) -> list[Document]:
-        plaintext_file_key = ''
-        plaintext_file_exists = False
-        if self._upload_file:
-            if self._upload_file.hash:
-                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
-                                     + self._upload_file.hash + '.0625.plaintext'
-                try:
-                    text = storage.load(plaintext_file_key).decode('utf-8')
-                    plaintext_file_exists = True
-                    return [Document(page_content=text)]
-                except FileNotFoundError:
-                    pass
-        documents = PyPDFium2Loader(file_path=self._file_path).load()
-        text_list = []
-        for document in documents:
-            text_list.append(document.page_content)
-        text = "\n\n".join(text_list)
-
-        # save plaintext file for caching
-        if not plaintext_file_exists and plaintext_file_key:
-            storage.save(plaintext_file_key, text.encode('utf-8'))
-
-        return documents
-

+ 1 - 1
api/core/docstore/dataset_docstore.py

@@ -1,12 +1,12 @@
 from collections.abc import Sequence
 from typing import Any, Optional, cast
 
-from langchain.schema import Document
 from sqlalchemy import func
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 

+ 7 - 31
api/core/features/annotation_reply.py

@@ -1,13 +1,8 @@
 import logging
 from typing import Optional
 
-from flask import current_app
-
-from core.embedding.cached_embedding import CacheEmbedding
 from core.entities.application_entities import InvokeFrom
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
             embedding_provider_name = collection_binding_detail.provider_name
             embedding_model_name = collection_binding_detail.model_name
 
-            model_manager = ModelManager()
-            model_instance = model_manager.get_model_instance(
-                tenant_id=app_record.tenant_id,
-                provider=embedding_provider_name,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=embedding_model_name
-            )
-
-            # get embedding model
-            embeddings = CacheEmbedding(model_instance)
-
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                 embedding_provider_name,
                 embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                 collection_binding_id=dataset_collection_binding.id
             )
 
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings,
-                attributes=['doc_id', 'annotation_id', 'app_id']
-            )
+            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
 
-            documents = vector_index.search(
+            documents = vector.search_by_vector(
                 query=query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': 1,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
+                k=1,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
                 }
             )
 

+ 0 - 51
api/core/index/index.py

@@ -1,51 +0,0 @@
-from flask import current_app
-from langchain.embeddings import OpenAIEmbeddings
-
-from core.embedding.cached_embedding import CacheEmbedding
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from models.dataset import Dataset
-
-
-class IndexBuilder:
-    @classmethod
-    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
-        if indexing_technique == "high_quality":
-            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
-                return None
-
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                model_type=ModelType.TEXT_EMBEDDING,
-                provider=dataset.embedding_model_provider,
-                model=dataset.embedding_model
-            )
-
-            embeddings = CacheEmbedding(embedding_model)
-
-            return VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-        elif indexing_technique == "economy":
-            return KeywordTableIndex(
-                dataset=dataset,
-                config=KeywordTableConfig(
-                    max_keywords_per_chunk=10
-                )
-            )
-        else:
-            raise ValueError('Unknown indexing technique')
-
-    @classmethod
-    def get_default_high_quality_index(cls, dataset: Dataset):
-        embeddings = OpenAIEmbeddings(openai_api_key=' ')
-        return VectorIndex(
-            dataset=dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )

+ 0 - 305
api/core/index/vector_index/base.py

@@ -1,305 +0,0 @@
-import json
-import logging
-from abc import abstractmethod
-from typing import Any, cast
-
-from langchain.embeddings.base import Embeddings
-from langchain.schema import BaseRetriever, Document
-from langchain.vectorstores import VectorStore
-
-from core.index.base import BaseIndex
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
-from models.dataset import Document as DatasetDocument
-
-
-class BaseVectorIndex(BaseIndex):
-
-    def __init__(self, dataset: Dataset, embeddings: Embeddings):
-        super().__init__(dataset)
-        self._embeddings = embeddings
-        self._vector_store = None
-
-    def get_type(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_index_name(self, dataset: Dataset) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def to_index_struct(self) -> dict:
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_vector_store(self) -> VectorStore:
-        raise NotImplementedError
-
-    @abstractmethod
-    def _get_vector_store_class(self) -> type:
-        raise NotImplementedError
-
-    @abstractmethod
-    def search_by_full_text_index(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
-        raise NotImplementedError
-
-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
-        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-
-        if search_type == 'similarity_score_threshold':
-            score_threshold = search_kwargs.get("score_threshold")
-            if (score_threshold is None) or (not isinstance(score_threshold, float)):
-                search_kwargs['score_threshold'] = .0
-
-            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
-                query, **search_kwargs
-            )
-
-            docs = []
-            for doc, similarity in docs_with_similarity:
-                doc.metadata['score'] = similarity
-                docs.append(doc)
-
-            return docs
-
-        # similarity k
-        # mmr k, fetch_k, lambda_mult
-        # similarity_score_threshold k
-        return vector_store.as_retriever(
-            search_type=search_type,
-            search_kwargs=search_kwargs
-        ).get_relevant_documents(query)
-
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        return vector_store.as_retriever(**kwargs)
-
-    def add_texts(self, texts: list[Document], **kwargs):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        if kwargs.get('duplicate_check', False):
-            texts = self._filter_duplicate_texts(texts)
-
-        uuids = self._get_uuids(texts)
-        vector_store.add_documents(texts, uuids=uuids)
-
-    def text_exists(self, id: str) -> bool:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        return vector_store.text_exists(id)
-
-    def delete_by_ids(self, ids: list[str]) -> None:
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        for node_id in ids:
-            vector_store.del_text(node_id)
-
-    def delete_by_group_id(self, group_id: str) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        if self.dataset.collection_binding_id:
-            vector_store.delete_by_group_id(group_id)
-        else:
-            vector_store.delete()
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def _is_origin(self):
-        return False
-
-    def recreate_dataset(self, dataset: Dataset):
-        logging.info(f"Recreating dataset {dataset.id}")
-
-        try:
-            self.delete()
-        except Exception as e:
-            raise e
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        origin_index_struct = self.dataset.index_struct[:]
-        self.dataset.index_struct = None
-
-        if documents:
-            try:
-                self.create(documents)
-            except Exception as e:
-                self.dataset.index_struct = origin_index_struct
-                raise e
-
-            dataset.index_struct = json.dumps(self.to_index_struct())
-
-        db.session.commit()
-
-        self.dataset = dataset
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def create_qdrant_dataset(self, dataset: Dataset):
-        logging.info(f"create_qdrant_dataset {dataset.id}")
-
-        try:
-            self.delete()
-        except Exception as e:
-            raise e
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        if documents:
-            try:
-                self.create(documents)
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def update_qdrant_dataset(self, dataset: Dataset):
-        logging.info(f"update_qdrant_dataset {dataset.id}")
-
-        segment = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == dataset.id,
-            DocumentSegment.status == 'completed',
-            DocumentSegment.enabled == True
-        ).first()
-
-        if segment:
-            try:
-                exist = self.text_exists(segment.index_node_id)
-                if exist:
-                    index_struct = {
-                        "type": 'qdrant',
-                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                    }
-                    dataset.index_struct = json.dumps(index_struct)
-                    db.session.commit()
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
-        logging.info(f"restore dataset in_one,_dataset {dataset.id}")
-
-        dataset_documents = db.session.query(DatasetDocument).filter(
-            DatasetDocument.dataset_id == dataset.id,
-            DatasetDocument.indexing_status == 'completed',
-            DatasetDocument.enabled == True,
-            DatasetDocument.archived == False,
-        ).all()
-
-        documents = []
-        for dataset_document in dataset_documents:
-            segments = db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == dataset_document.id,
-                DocumentSegment.status == 'completed',
-                DocumentSegment.enabled == True
-            ).all()
-
-            for segment in segments:
-                document = Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": segment.index_node_id,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                )
-
-                documents.append(document)
-
-        if documents:
-            try:
-                self.add_texts(documents)
-            except Exception as e:
-                raise e
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")
-
-    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
-        logging.info(f"delete original collection: {dataset.id}")
-
-        self.delete()
-
-        dataset.collection_binding_id = dataset_collection_binding.id
-        db.session.add(dataset)
-        db.session.commit()
-
-        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 0 - 165
api/core/index/vector_index/milvus_vector_index.py

@@ -1,165 +0,0 @@
-from typing import Any, cast
-
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel, root_validator
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.milvus_vector_store import MilvusVectorStore
-from models.dataset import Dataset
-
-
-class MilvusConfig(BaseModel):
-    host: str
-    port: int
-    user: str
-    password: str
-    secure: bool = False
-    batch_size: int = 100
-
-    @root_validator()
-    def validate_config(cls, values: dict) -> dict:
-        if not values['host']:
-            raise ValueError("config MILVUS_HOST is required")
-        if not values['port']:
-            raise ValueError("config MILVUS_PORT is required")
-        if not values['user']:
-            raise ValueError("config MILVUS_USER is required")
-        if not values['password']:
-            raise ValueError("config MILVUS_PASSWORD is required")
-        return values
-
-    def to_milvus_params(self):
-        return {
-            'host': self.host,
-            'port': self.port,
-            'user': self.user,
-            'password': self.password,
-            'secure': self.secure
-        }
-
-
-class MilvusVectorIndex(BaseVectorIndex):
-    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
-        super().__init__(dataset, embeddings)
-        self._client_config = config
-
-    def get_type(self) -> str:
-        return 'milvus'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-
-            return class_prefix
-
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        index_params = {
-            'metric_type': 'IP',
-            'index_type': "HNSW",
-            'params': {"M": 8, "efConstruction": 64}
-        }
-        self._vector_store = MilvusVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=self.get_index_name(self.dataset),
-            connection_args=self._client_config.to_milvus_params(),
-            index_params=index_params
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = MilvusVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=collection_name,
-            ids=uuids,
-            content_payload_key='page_content'
-        )
-
-        return self
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-
-        return MilvusVectorStore(
-            collection_name=self.get_index_name(self.dataset),
-            embedding_function=self._embeddings,
-            connection_args=self._client_config.to_milvus_params()
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return MilvusVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_document_id(document_id)
-        if ids:
-            vector_store.del_texts({
-                'filter': f'id in {ids}'
-            })
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_metadata_field(key, value)
-        if ids:
-            vector_store.del_texts({
-                'filter': f'id in {ids}'
-            })
-
-    def delete_by_ids(self, doc_ids: list[str]) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        ids = vector_store.get_ids_by_doc_ids(doc_ids)
-        vector_store.del_texts({
-            'filter': f' id in {ids}'
-        })
-
-    def delete_by_group_id(self, group_id: str) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-            ],
-        ))
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        # milvus/zilliz doesn't support bm25 search
-        return []

+ 0 - 229
api/core/index/vector_index/qdrant_vector_index.py

@@ -1,229 +0,0 @@
-import os
-from typing import Any, Optional, cast
-
-import qdrant_client
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel
-from qdrant_client.http.models import HnswConfigDiff
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
-
-
-class QdrantConfig(BaseModel):
-    endpoint: str
-    api_key: Optional[str]
-    timeout: float = 20
-    root_path: Optional[str]
-
-    def to_qdrant_params(self):
-        if self.endpoint and self.endpoint.startswith('path:'):
-            path = self.endpoint.replace('path:', '')
-            if not os.path.isabs(path):
-                path = os.path.join(self.root_path, path)
-
-            return {
-                'path': path
-            }
-        else:
-            return {
-                'url': self.endpoint,
-                'api_key': self.api_key,
-                'timeout': self.timeout
-            }
-
-
-class QdrantVectorIndex(BaseVectorIndex):
-    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
-        super().__init__(dataset, embeddings)
-        self._client_config = config
-
-    def get_type(self) -> str:
-        return 'qdrant'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if dataset.collection_binding_id:
-            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
-                one_or_none()
-            if dataset_collection_binding:
-                return dataset_collection_binding.collection_name
-            else:
-                raise ValueError('Dataset Collection Bindings is not exist!')
-        else:
-            if self.dataset.index_struct_dict:
-                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-                return class_prefix
-
-            dataset_id = dataset.id
-            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = QdrantVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=self.get_index_name(self.dataset),
-            ids=uuids,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id',
-            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                       max_indexing_threads=0, on_disk=False),
-            **self._client_config.to_qdrant_params()
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = QdrantVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            collection_name=collection_name,
-            ids=uuids,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id',
-            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
-                                       max_indexing_threads=0, on_disk=False),
-            **self._client_config.to_qdrant_params()
-        )
-
-        return self
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-        attributes = ['doc_id', 'dataset_id', 'document_id']
-        client = qdrant_client.QdrantClient(
-            **self._client_config.to_qdrant_params()
-        )
-
-        return QdrantVectorStore(
-            client=client,
-            collection_name=self.get_index_name(self.dataset),
-            embeddings=self._embeddings,
-            content_payload_key='page_content',
-            group_id=self.dataset.id,
-            group_payload_key='group_id'
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return QdrantVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="metadata.document_id",
-                    match=models.MatchValue(value=document_id),
-                ),
-            ],
-        ))
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key=f"metadata.{key}",
-                    match=models.MatchValue(value=value),
-                ),
-            ],
-        ))
-
-    def delete_by_ids(self, ids: list[str]) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        for node_id in ids:
-            vector_store.del_texts(models.Filter(
-                must=[
-                    models.FieldCondition(
-                        key="metadata.doc_id",
-                        match=models.MatchValue(value=node_id),
-                    ),
-                ],
-            ))
-
-    def delete_by_group_id(self, group_id: str) -> None:
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=group_id),
-                ),
-            ],
-        ))
-
-    def delete(self) -> None:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        vector_store.del_texts(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-            ],
-        ))
-
-    def _is_origin(self):
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                return True
-
-        return False
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        from qdrant_client.http import models
-        return vector_store.similarity_search_by_bm25(models.Filter(
-            must=[
-                models.FieldCondition(
-                    key="group_id",
-                    match=models.MatchValue(value=self.dataset.id),
-                ),
-                models.FieldCondition(
-                    key="page_content",
-                    match=models.MatchText(text=query),
-                )
-            ],
-        ), kwargs.get('top_k', 2))

+ 0 - 90
api/core/index/vector_index/vector_index.py

@@ -1,90 +0,0 @@
-import json
-
-from flask import current_app
-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.base import BaseVectorIndex
-from extensions.ext_database import db
-from models.dataset import Dataset, Document
-
-
-class VectorIndex:
-    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
-                 attributes: list = None):
-        if attributes is None:
-            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
-        self._dataset = dataset
-        self._embeddings = embeddings
-        self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
-        self._attributes = attributes
-
-    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
-                           attributes: list) -> BaseVectorIndex:
-        vector_type = config.get('VECTOR_STORE')
-
-        if self._dataset.index_struct_dict:
-            vector_type = self._dataset.index_struct_dict['type']
-
-        if not vector_type:
-            raise ValueError("Vector store must be specified.")
-
-        if vector_type == "weaviate":
-            from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
-
-            return WeaviateVectorIndex(
-                dataset=dataset,
-                config=WeaviateConfig(
-                    endpoint=config.get('WEAVIATE_ENDPOINT'),
-                    api_key=config.get('WEAVIATE_API_KEY'),
-                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
-                ),
-                embeddings=embeddings,
-                attributes=attributes
-            )
-        elif vector_type == "qdrant":
-            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-            return QdrantVectorIndex(
-                dataset=dataset,
-                config=QdrantConfig(
-                    endpoint=config.get('QDRANT_URL'),
-                    api_key=config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path,
-                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
-                ),
-                embeddings=embeddings
-            )
-        elif vector_type == "milvus":
-            from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
-
-            return MilvusVectorIndex(
-                dataset=dataset,
-                config=MilvusConfig(
-                    host=config.get('MILVUS_HOST'),
-                    port=config.get('MILVUS_PORT'),
-                    user=config.get('MILVUS_USER'),
-                    password=config.get('MILVUS_PASSWORD'),
-                    secure=config.get('MILVUS_SECURE'),
-                ),
-                embeddings=embeddings
-            )
-        else:
-            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
-
-    def add_texts(self, texts: list[Document], **kwargs):
-        if not self._dataset.index_struct_dict:
-            self._vector_index.create(texts, **kwargs)
-            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
-            db.session.commit()
-            return
-
-        self._vector_index.add_texts(texts, **kwargs)
-
-    def __getattr__(self, name):
-        if self._vector_index is not None:
-            method = getattr(self._vector_index, name)
-            if callable(method):
-                return method
-
-        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
-

+ 0 - 179
api/core/index/vector_index/weaviate_vector_index.py

@@ -1,179 +0,0 @@
-from typing import Any, Optional, cast
-
-import requests
-import weaviate
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
-from pydantic import BaseModel, root_validator
-
-from core.index.base import BaseIndex
-from core.index.vector_index.base import BaseVectorIndex
-from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from models.dataset import Dataset
-
-
-class WeaviateConfig(BaseModel):
-    endpoint: str
-    api_key: Optional[str]
-    batch_size: int = 100
-
-    @root_validator()
-    def validate_config(cls, values: dict) -> dict:
-        if not values['endpoint']:
-            raise ValueError("config WEAVIATE_ENDPOINT is required")
-        return values
-
-
-class WeaviateVectorIndex(BaseVectorIndex):
-
-    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
-        super().__init__(dataset, embeddings)
-        self._client = self._init_client(config)
-        self._attributes = attributes
-
-    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
-        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
-
-        weaviate.connect.connection.has_grpc = False
-
-        try:
-            client = weaviate.Client(
-                url=config.endpoint,
-                auth_client_secret=auth_config,
-                timeout_config=(5, 60),
-                startup_period=None
-            )
-        except requests.exceptions.ConnectionError:
-            raise ConnectionError("Vector database connection error")
-
-        client.batch.configure(
-            # `batch_size` takes an `int` value to enable auto-batching
-            # (`None` is used for manual batching)
-            batch_size=config.batch_size,
-            # dynamically update the `batch_size` based on import speed
-            dynamic=True,
-            # `timeout_retries` takes an `int` value to retry on time outs
-            timeout_retries=3,
-        )
-
-        return client
-
-    def get_type(self) -> str:
-        return 'weaviate'
-
-    def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-
-            return class_prefix
-
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
-
-    def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
-        }
-
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            uuids=uuids,
-            by_text=False
-        )
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
-            texts,
-            self._embeddings,
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            uuids=uuids,
-            by_text=False
-        )
-
-        return self
-
-
-    def _get_vector_store(self) -> VectorStore:
-        """Only for created index."""
-        if self._vector_store:
-            return self._vector_store
-
-        attributes = self._attributes
-        if self._is_origin():
-            attributes = ['doc_id']
-
-        return WeaviateVectorStore(
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            text_key='text',
-            embedding=self._embeddings,
-            attributes=attributes,
-            by_text=False
-        )
-
-    def _get_vector_store_class(self) -> type:
-        return WeaviateVectorStore
-
-    def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.del_texts({
-            "operator": "Equal",
-            "path": ["document_id"],
-            "valueText": document_id
-        })
-
-    def delete_by_metadata_field(self, key: str, value: str):
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.del_texts({
-            "operator": "Equal",
-            "path": [key],
-            "valueText": value
-        })
-
-    def delete_by_group_id(self, group_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
-
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
-
-    def _is_origin(self):
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                return True
-
-        return False
-
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
-        vector_store = self._get_vector_store()
-        vector_store = cast(self._get_vector_store_class(), vector_store)
-        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
-

+ 134 - 252
api/core/indexing_runner.py

@@ -9,20 +9,20 @@ from typing import Optional, cast
 
 from flask import Flask, current_app
 from flask_login import current_user
-from langchain.schema import Document
 from langchain.text_splitter import TextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 
-from core.data_loader.file_extractor import FileExtractor
-from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatasetDocumentStore
 from core.errors.error import ProviderTokenNotInitError
 from core.generator.llm_generator import LLMGenerator
-from core.index.index import IndexBuilder
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -31,7 +31,6 @@ from libs import helper
 from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
-from models.source import DataSourceBinding
 from services.feature_service import FeatureService
 
 
@@ -57,38 +56,19 @@ class IndexingRunner:
                 processing_rule = db.session.query(DatasetProcessRule). \
                     filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                     first()
-
-                # load file
-                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-
-                # get embedding model instance
-                embedding_model_instance = None
-                if dataset.indexing_technique == 'high_quality':
-                    if dataset.embedding_model_provider:
-                        embedding_model_instance = self.model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=dataset.embedding_model_provider,
-                            model_type=ModelType.TEXT_EMBEDDING,
-                            model=dataset.embedding_model
-                        )
-                    else:
-                        embedding_model_instance = self.model_manager.get_default_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            model_type=ModelType.TEXT_EMBEDDING,
-                        )
-
-                # get splitter
-                splitter = self._get_splitter(processing_rule, embedding_model_instance)
-
-                # split to documents
-                documents = self._step_split(
-                    text_docs=text_docs,
-                    splitter=splitter,
-                    dataset=dataset,
-                    dataset_document=dataset_document,
-                    processing_rule=processing_rule
-                )
-                self._build_index(
+                index_type = dataset_document.doc_form
+                index_processor = IndexProcessorFactory(index_type).init_index_processor()
+                # extract
+                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
+
+                # transform
+                documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+                # save segment
+                self._load_segments(dataset, dataset_document, documents)
+
+                # load
+                self._load(
+                    index_processor=index_processor,
                     dataset=dataset,
                     dataset_document=dataset_document,
                     documents=documents
@@ -134,39 +114,19 @@ class IndexingRunner:
                 filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()
 
-            # load file
-            text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')
-
-            # get embedding model instance
-            embedding_model_instance = None
-            if dataset.indexing_technique == 'high_quality':
-                if dataset.embedding_model_provider:
-                    embedding_model_instance = self.model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                else:
-                    embedding_model_instance = self.model_manager.get_default_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                    )
-
-            # get splitter
-            splitter = self._get_splitter(processing_rule, embedding_model_instance)
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
+            # extract
+            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
 
-            # split to documents
-            documents = self._step_split(
-                text_docs=text_docs,
-                splitter=splitter,
-                dataset=dataset,
-                dataset_document=dataset_document,
-                processing_rule=processing_rule
-            )
+            # transform
+            documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+            # save segment
+            self._load_segments(dataset, dataset_document, documents)
 
-            # build index
-            self._build_index(
+            # load
+            self._load(
+                index_processor=index_processor,
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
@@ -220,7 +180,15 @@ class IndexingRunner:
                         documents.append(document)
 
             # build index
-            self._build_index(
+            # get the process rule
+            processing_rule = db.session.query(DatasetProcessRule). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                first()
+
+            index_type = dataset_document.doc_form
+            index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
+            self._load(
+                index_processor=index_processor,
                 dataset=dataset,
                 dataset_document=dataset_document,
                 documents=documents
@@ -239,16 +207,16 @@ class IndexingRunner:
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
 
-    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
-                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
-                               indexing_technique: str = 'economy') -> dict:
+    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
+                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
+                          indexing_technique: str = 'economy') -> dict:
         """
         Estimate the indexing for the document.
         """
         # check document limit
         features = FeatureService.get_features(tenant_id)
         if features.billing.enabled:
-            count = len(file_details)
+            count = len(extract_settings)
             batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -284,16 +252,18 @@ class IndexingRunner:
         total_segments = 0
         total_price = 0
         currency = 'USD'
-        for file_detail in file_details:
-
+        index_type = doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        all_text_docs = []
+        for extract_setting in extract_settings:
+            # extract
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+            all_text_docs.extend(text_docs)
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )
 
-            # load data from file
-            text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')
-
             # get splitter
             splitter = self._get_splitter(processing_rule, embedding_model_instance)
 
@@ -305,7 +275,6 @@ class IndexingRunner:
             )
 
             total_segments += len(documents)
-
             for document in documents:
                 if len(preview_texts) < 5:
                     preview_texts.append(document.page_content)
@@ -364,154 +333,8 @@ class IndexingRunner:
             "preview": preview_texts
         }
 
-    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
-                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
-                                 indexing_technique: str = 'economy') -> dict:
-        """
-        Estimate the indexing for the document.
-        """
-        # check document limit
-        features = FeatureService.get_features(tenant_id)
-        if features.billing.enabled:
-            count = len(notion_info_list)
-            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
-            if count > batch_upload_limit:
-                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-
-        embedding_model_instance = None
-        if dataset_id:
-            dataset = Dataset.query.filter_by(
-                id=dataset_id
-            ).first()
-            if not dataset:
-                raise ValueError('Dataset not found.')
-            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
-                if dataset.embedding_model_provider:
-                    embedding_model_instance = self.model_manager.get_model_instance(
-                        tenant_id=tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                else:
-                    embedding_model_instance = self.model_manager.get_default_model_instance(
-                        tenant_id=tenant_id,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                    )
-        else:
-            if indexing_technique == 'high_quality':
-                embedding_model_instance = self.model_manager.get_default_model_instance(
-                    tenant_id=tenant_id,
-                    model_type=ModelType.TEXT_EMBEDDING
-                )
-        # load data from notion
-        tokens = 0
-        preview_texts = []
-        total_segments = 0
-        total_price = 0
-        currency = 'USD'
-        for notion_info in notion_info_list:
-            workspace_id = notion_info['workspace_id']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == current_user.current_tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-
-            for page in notion_info['pages']:
-                loader = NotionLoader(
-                    notion_access_token=data_source_binding.access_token,
-                    notion_workspace_id=workspace_id,
-                    notion_obj_id=page['page_id'],
-                    notion_page_type=page['type']
-                )
-                documents = loader.load()
-
-                processing_rule = DatasetProcessRule(
-                    mode=tmp_processing_rule["mode"],
-                    rules=json.dumps(tmp_processing_rule["rules"])
-                )
-
-                # get splitter
-                splitter = self._get_splitter(processing_rule, embedding_model_instance)
-
-                # split to documents
-                documents = self._split_to_documents_for_estimate(
-                    text_docs=documents,
-                    splitter=splitter,
-                    processing_rule=processing_rule
-                )
-                total_segments += len(documents)
-
-                embedding_model_type_instance = None
-                if embedding_model_instance:
-                    embedding_model_type_instance = embedding_model_instance.model_type_instance
-                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-
-                for document in documents:
-                    if len(preview_texts) < 5:
-                        preview_texts.append(document.page_content)
-                    if indexing_technique == 'high_quality' and embedding_model_type_instance:
-                        tokens += embedding_model_type_instance.get_num_tokens(
-                            model=embedding_model_instance.model,
-                            credentials=embedding_model_instance.credentials,
-                            texts=[document.page_content]
-                        )
-
-        if doc_form and doc_form == 'qa_model':
-            model_instance = self.model_manager.get_default_model_instance(
-                tenant_id=tenant_id,
-                model_type=ModelType.LLM
-            )
-
-            model_type_instance = model_instance.model_type_instance
-            model_type_instance = cast(LargeLanguageModel, model_type_instance)
-            if len(preview_texts) > 0:
-                # qa model document
-                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
-                                                             doc_language)
-                document_qa_list = self.format_split_text(response)
-
-                price_info = model_type_instance.get_price(
-                    model=model_instance.model,
-                    credentials=model_instance.credentials,
-                    price_type=PriceType.INPUT,
-                    tokens=total_segments * 2000,
-                )
-
-                return {
-                    "total_segments": total_segments * 20,
-                    "tokens": total_segments * 2000,
-                    "total_price": '{:f}'.format(price_info.total_amount),
-                    "currency": price_info.currency,
-                    "qa_preview": document_qa_list,
-                    "preview": preview_texts
-                }
-        if embedding_model_instance:
-            embedding_model_type_instance = embedding_model_instance.model_type_instance
-            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-            embedding_price_info = embedding_model_type_instance.get_price(
-                model=embedding_model_instance.model,
-                credentials=embedding_model_instance.credentials,
-                price_type=PriceType.INPUT,
-                tokens=tokens
-            )
-            total_price = '{:f}'.format(embedding_price_info.total_amount)
-            currency = embedding_price_info.currency
-        return {
-            "total_segments": total_segments,
-            "tokens": tokens,
-            "total_price": total_price,
-            "currency": currency,
-            "preview": preview_texts
-        }
-
-    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
+    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
+            -> list[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -527,11 +350,27 @@ class IndexingRunner:
                 one_or_none()
 
             if file_detail:
-                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=dataset_document.doc_form
+                )
+                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         elif dataset_document.data_source_type == 'notion_import':
-            loader = NotionLoader.from_document(dataset_document)
-            text_docs = loader.load()
-
+            if (not data_source_info or 'notion_workspace_id' not in data_source_info
+                    or 'notion_page_id' not in data_source_info):
+                raise ValueError("no notion import info found")
+            extract_setting = ExtractSetting(
+                datasource_type="notion_import",
+                notion_info={
+                    "notion_workspace_id": data_source_info['notion_workspace_id'],
+                    "notion_obj_id": data_source_info['notion_page_id'],
+                    "notion_page_type": data_source_info['notion_page_type'],
+                    "document": dataset_document
+                },
+                document_model=dataset_document.doc_form
+            )
+            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
         # update document status to splitting
         self._update_document_index_status(
             document_id=dataset_document.id,
@@ -545,8 +384,6 @@ class IndexingRunner:
         # replace doc id to document model id
         text_docs = cast(list[Document], text_docs)
         for text_doc in text_docs:
-            # remove invalid symbol
-            text_doc.page_content = self.filter_string(text_doc.page_content)
             text_doc.metadata['document_id'] = dataset_document.id
             text_doc.metadata['dataset_id'] = dataset_document.dataset_id
 
@@ -787,12 +624,12 @@ class IndexingRunner:
             for q, a in matches if q and a
         ]
 
-    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
+    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+              dataset_document: DatasetDocument, documents: list[Document]) -> None:
         """
-        Build the index for the document.
+        insert index and update document/segment status to completed
         """
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
+
         embedding_model_instance = None
         if dataset.indexing_technique == 'high_quality':
             embedding_model_instance = self.model_manager.get_model_instance(
@@ -825,13 +662,8 @@ class IndexingRunner:
                     )
                     for document in chunk_documents
                 )
-
-            # save vector index
-            if vector_index:
-                vector_index.add_texts(chunk_documents)
-
-            # save keyword index
-            keyword_table_index.add_texts(chunk_documents)
+            # load index
+            index_processor.load(dataset, chunk_documents)
 
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
@@ -911,14 +743,64 @@ class IndexingRunner:
             )
             documents.append(document)
         # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents, duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts(documents)
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.load(dataset, documents)
+
+    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
+                   text_docs: list[Document], process_rule: dict) -> list[Document]:
+        # get embedding model instance
+        embedding_model_instance = None
+        if dataset.indexing_technique == 'high_quality':
+            if dataset.embedding_model_provider:
+                embedding_model_instance = self.model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model
+                )
+            else:
+                embedding_model_instance = self.model_manager.get_default_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                )
+
+        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
+                                              process_rule=process_rule)
+
+        return documents
+
+    def _load_segments(self, dataset, dataset_document, documents):
+        # save node to document segment
+        doc_store = DatasetDocumentStore(
+            dataset=dataset,
+            user_id=dataset_document.created_by,
+            document_id=dataset_document.id
+        )
+
+        # add document segments
+        doc_store.add_documents(documents)
+
+        # update document status to indexing
+        cur_time = datetime.datetime.utcnow()
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="indexing",
+            extra_update_params={
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
+            }
+        )
+
+        # update segment status to indexing
+        self._update_segments_by_document(
+            dataset_document_id=dataset_document.id,
+            update_params={
+                DocumentSegment.status: "indexing",
+                DocumentSegment.indexing_at: datetime.datetime.utcnow()
+            }
+        )
+        pass
 
 
 class DocumentIsPausedException(Exception):

+ 0 - 0
api/core/rag/__init__.py


+ 38 - 0
api/core/rag/cleaner/clean_processor.py

@@ -0,0 +1,38 @@
+import re
+
+
+class CleanProcessor:
+
+    @classmethod
+    def clean(cls, text: str, process_rule: dict) -> str:
+        # default clean
+        # remove invalid symbol
+        text = re.sub(r'<\|', '<', text)
+        text = re.sub(r'\|>', '>', text)
+        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
+        # Unicode  U+FFFE
+        text = re.sub('\uFFFE', '', text)
+
+        rules = process_rule['rules'] if process_rule else None
+        if 'pre_processing_rules' in rules:
+            pre_processing_rules = rules["pre_processing_rules"]
+            for pre_processing_rule in pre_processing_rules:
+                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
+                    # Remove extra spaces
+                    pattern = r'\n{3,}'
+                    text = re.sub(pattern, '\n\n', text)
+                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
+                    text = re.sub(pattern, ' ', text)
+                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
+                    # Remove email
+                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
+                    text = re.sub(pattern, '', text)
+
+                    # Remove URL
+                    pattern = r'https?://[^\s]+'
+                    text = re.sub(pattern, '', text)
+        return text
+
+    def filter_string(self, text):
+
+        return text

+ 12 - 0
api/core/rag/cleaner/cleaner_base.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document cleaner implementations."""
+from abc import ABC, abstractmethod
+
+
+class BaseCleaner(ABC):
+    """Interface for clean chunk content.
+    """
+
+    @abstractmethod
+    def clean(self, content: str):
+        raise NotImplementedError
+

+ 12 - 0
api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.core import clean_extra_whitespace
+
+        # Returns "ITEM 1A: RISK FACTORS"
+        return clean_extra_whitespace(content)

+ 15 - 0
api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py

@@ -0,0 +1,15 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        import re
+
+        from unstructured.cleaners.core import group_broken_paragraphs
+
+        para_split_re = re.compile(r"(\s*\n\s*){3}")
+
+        return group_broken_paragraphs(content, paragraph_split=para_split_re)

+ 12 - 0
api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.core import clean_non_ascii_chars
+
+        # Returns "This text containsnon-ascii characters!"
+        return clean_non_ascii_chars(content)

+ 11 - 0
api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py

@@ -0,0 +1,11 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """Replaces unicode quote characters, such as the \x91 character in a string."""
+
+        from unstructured.cleaners.core import replace_unicode_quotes
+        return replace_unicode_quotes(content)

+ 11 - 0
api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py

@@ -0,0 +1,11 @@
+"""Abstract interface for document clean implementations."""
+from core.rag.cleaner.cleaner_base import BaseCleaner
+
+
+class UnstructuredTranslateTextCleaner(BaseCleaner):
+
+    def clean(self, content) -> str:
+        """clean document content."""
+        from unstructured.cleaners.translate import translate_text
+
+        return translate_text(content)

+ 0 - 0
api/core/rag/data_post_processor/__init__.py


+ 49 - 0
api/core/rag/data_post_processor/data_post_processor.py

@@ -0,0 +1,49 @@
+from typing import Optional
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.data_post_processor.reorder import ReorderRunner
+from core.rag.models.document import Document
+from core.rerank.rerank import RerankRunner
+
+
+class DataPostProcessor:
+    """Interface for data post-processing document.
+    """
+
+    def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
+        self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
+        self.reorder_runner = self._get_reorder_runner(reorder_enabled)
+
+    def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
+               top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
+        if self.rerank_runner:
+            documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
+
+        if self.reorder_runner:
+            documents = self.reorder_runner.run(documents)
+
+        return documents
+
+    def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
+        if reranking_model:
+            try:
+                model_manager = ModelManager()
+                rerank_model_instance = model_manager.get_model_instance(
+                    tenant_id=tenant_id,
+                    provider=reranking_model['reranking_provider_name'],
+                    model_type=ModelType.RERANK,
+                    model=reranking_model['reranking_model_name']
+                )
+            except InvokeAuthorizationError:
+                return None
+            return RerankRunner(rerank_model_instance)
+        return None
+
+    def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
+        if reorder_enabled:
+            return ReorderRunner()
+        return None
+
+

+ 19 - 0
api/core/rag/data_post_processor/reorder.py

@@ -0,0 +1,19 @@
+
+from langchain.schema import Document
+
+
+class ReorderRunner:
+
+    def run(self, documents: list[Document]) -> list[Document]:
+        # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list
+        odd_elements = documents[::2]
+
+        # Retrieve elements from even indices (1, 3, 5, etc.) of the documents list
+        even_elements = documents[1::2]
+
+        # Reverse the list of elements from even indices
+        even_elements_reversed = even_elements[::-1]
+
+        new_documents = odd_elements + even_elements_reversed
+
+        return new_documents

+ 0 - 0
api/core/rag/datasource/__init__.py


+ 21 - 0
api/core/rag/datasource/entity/embedding.py

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class Embeddings(ABC):
+    """Interface for embedding models."""
+
+    @abstractmethod
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
+        """Embed search docs."""
+
+    @abstractmethod
+    def embed_query(self, text: str) -> list[float]:
+        """Embed query text."""
+
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+        """Asynchronous Embed search docs."""
+        raise NotImplementedError
+
+    async def aembed_query(self, text: str) -> list[float]:
+        """Asynchronous Embed query text."""
+        raise NotImplementedError

+ 0 - 0
api/core/rag/datasource/keyword/__init__.py


+ 0 - 0
api/core/rag/datasource/keyword/jieba/__init__.py


+ 18 - 90
api/core/index/keyword_table_index/keyword_table_index.py → api/core/rag/datasource/keyword/jieba/jieba.py

@@ -2,11 +2,11 @@ import json
 from collections import defaultdict
 from typing import Any, Optional
 
-from langchain.schema import BaseRetriever, Document
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel
 
-from core.index.base import BaseIndex
-from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
 
@@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
     max_keywords_per_chunk: int = 10
 
 
-class KeywordTableIndex(BaseIndex):
-    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+class Jieba(BaseKeyword):
+    def __init__(self, dataset: Dataset):
         super().__init__(dataset)
-        self._config = config
+        self._config = KeywordTableConfig()
 
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
         keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
-        self._save_dataset_keyword_table(keyword_table)
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
+        keyword_table = self._get_dataset_keyword_table()
         for text in texts:
             keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
             self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
 
-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
         self._save_dataset_keyword_table(keyword_table)
 
         return self
@@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
         keyword_table_handler = JiebaKeywordTableHandler()
 
         keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+        keywords_list = kwargs.get('keywords_list', None)
+        for i in range(len(texts)):
+            text = texts[i]
+            if keywords_list:
+                keywords = keywords_list[i]
+            else:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
             self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
             keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
 
@@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
 
         self._save_dataset_keyword_table(keyword_table)
 
-    def delete_by_metadata_field(self, key: str, value: str):
-        pass
-
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-        return KeywordTableRetriever(index=self, **kwargs)
-
     def search(
             self, query: str,
             **kwargs: Any
     ) -> list[Document]:
         keyword_table = self._get_dataset_keyword_table()
 
-        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-        k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+        k = kwargs.get('top_k', 4)
 
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
 
@@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
             db.session.delete(dataset_keyword_table)
             db.session.commit()
 
-    def delete_by_group_id(self, group_id: str) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
             '__type__': 'keyword_table',
@@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
         ).first()
         if document_segment:
             document_segment.keywords = keywords
+            db.session.add(document_segment)
             db.session.commit()
 
     def create_segment_keywords(self, node_id: str, keywords: list[str]):
@@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
         self._save_dataset_keyword_table(keyword_table)
 
 
-class KeywordTableRetriever(BaseRetriever, BaseModel):
-    index: KeywordTableIndex
-    search_kwargs: dict = Field(default_factory=dict)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    def get_relevant_documents(self, query: str) -> list[Document]:
-        """Get documents relevant for a query.
-
-        Args:
-            query: string to find relevant documents for
-
-        Returns:
-            List of relevant documents
-        """
-        return self.index.search(query, **self.search_kwargs)
-
-    async def aget_relevant_documents(self, query: str) -> list[Document]:
-        raise NotImplementedError("KeywordTableRetriever does not support async")
-
-
 class SetEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, set):

+ 1 - 1
api/core/index/keyword_table_index/jieba_keyword_table_handler.py → api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -3,7 +3,7 @@ import re
 import jieba
 from jieba.analyse import default_tfidf
 
-from core.index.keyword_table_index.stopwords import STOPWORDS
+from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 
 
 class JiebaKeywordTableHandler:

+ 0 - 0
api/core/index/keyword_table_index/stopwords.py → api/core/rag/datasource/keyword/jieba/stopwords.py


+ 5 - 23
api/core/index/base.py → api/core/rag/datasource/keyword/keyword_base.py

@@ -3,22 +3,17 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any
 
-from langchain.schema import BaseRetriever, Document
-
+from core.rag.models.document import Document
 from models.dataset import Dataset
 
 
-class BaseIndex(ABC):
+class BaseKeyword(ABC):
 
     def __init__(self, dataset: Dataset):
         self.dataset = dataset
 
     @abstractmethod
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        raise NotImplementedError
-
-    @abstractmethod
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
         raise NotImplementedError
 
     @abstractmethod
@@ -34,31 +29,18 @@ class BaseIndex(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def delete_by_metadata_field(self, key: str, value: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_group_id(self, group_id: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_document_id(self, document_id: str):
+    def delete_by_document_id(self, document_id: str) -> None:
         raise NotImplementedError
 
-    @abstractmethod
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+    def delete(self) -> None:
         raise NotImplementedError
 
-    @abstractmethod
     def search(
             self, query: str,
             **kwargs: Any
     ) -> list[Document]:
         raise NotImplementedError
 
-    def delete(self) -> None:
-        raise NotImplementedError
-
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts:
             doc_id = text.metadata['doc_id']

+ 60 - 0
api/core/rag/datasource/keyword/keyword_factory.py

@@ -0,0 +1,60 @@
+from typing import Any, cast
+
+from flask import current_app
+
+from core.rag.datasource.keyword.jieba.jieba import Jieba
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
+from models.dataset import Dataset
+
+
+class Keyword:
+    def __init__(self, dataset: Dataset):
+        self._dataset = dataset
+        self._keyword_processor = self._init_keyword()
+
+    def _init_keyword(self) -> BaseKeyword:
+        config = cast(dict, current_app.config)
+        keyword_type = config.get('KEYWORD_STORE')
+
+        if not keyword_type:
+            raise ValueError("Keyword store must be specified.")
+
+        if keyword_type == "jieba":
+            return Jieba(
+                dataset=self._dataset
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def create(self, texts: list[Document], **kwargs):
+        self._keyword_processor.create(texts, **kwargs)
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        self._keyword_processor.add_texts(texts, **kwargs)
+
+    def text_exists(self, id: str) -> bool:
+        return self._keyword_processor.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._keyword_processor.delete_by_ids(ids)
+
+    def delete_by_document_id(self, document_id: str) -> None:
+        self._keyword_processor.delete_by_document_id(document_id)
+
+    def delete(self) -> None:
+        self._keyword_processor.delete()
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        return self._keyword_processor.search(query, **kwargs)
+
+    def __getattr__(self, name):
+        if self._keyword_processor is not None:
+            method = getattr(self._keyword_processor, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'Keyword' object has no attribute '{name}'")

+ 165 - 0
api/core/rag/datasource/retrieval_service.py

@@ -0,0 +1,165 @@
+import threading
+from typing import Optional
+
+from flask import Flask, current_app
+from flask_login import current_user
+
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
+from extensions.ext_database import db
+from models.dataset import Dataset
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enabled': False
+}
+
+
+class RetrievalService:
+
+    @classmethod
+    def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
+                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+        all_documents = []
+        threads = []
+        # retrieval_model source with keyword
+        if retrival_method == 'keyword_search':
+            keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'top_k': top_k
+            })
+            threads.append(keyword_thread)
+            keyword_thread.start()
+        # retrieval_model source with semantic
+        if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
+            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'top_k': top_k,
+                'score_threshold': score_threshold,
+                'reranking_model': reranking_model,
+                'all_documents': all_documents,
+                'retrival_method': retrival_method
+            })
+            threads.append(embedding_thread)
+            embedding_thread.start()
+
+        # retrieval source with full text
+        if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
+            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'retrival_method': retrival_method,
+                'score_threshold': score_threshold,
+                'top_k': top_k,
+                'reranking_model': reranking_model,
+                'all_documents': all_documents
+            })
+            threads.append(full_text_index_thread)
+            full_text_index_thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        if retrival_method == 'hybrid_search':
+            data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
+            all_documents = data_post_processor.invoke(
+                query=query,
+                documents=all_documents,
+                score_threshold=score_threshold,
+                top_n=top_k
+            )
+        return all_documents
+
+    @classmethod
+    def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                       top_k: int, all_documents: list):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            keyword = Keyword(
+                dataset=dataset
+            )
+
+            documents = keyword.search(
+                query,
+                k=top_k
+            )
+            all_documents.extend(documents)
+
+    @classmethod
+    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                         all_documents: list, retrival_method: str):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            vector = Vector(
+                dataset=dataset
+            )
+
+            documents = vector.search_by_vector(
+                query,
+                search_type='similarity_score_threshold',
+                k=top_k,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
+                }
+            )
+
+            if documents:
+                if reranking_model and retrival_method == 'semantic_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
+                        query=query,
+                        documents=documents,
+                        score_threshold=score_threshold,
+                        top_n=len(documents)
+                    ))
+                else:
+                    all_documents.extend(documents)
+
+    @classmethod
+    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
+                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                               all_documents: list, retrival_method: str):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
+
+            vector_processor = Vector(
+                dataset=dataset,
+            )
+
+            documents = vector_processor.search_by_full_text(
+                query,
+                top_k=top_k
+            )
+            if documents:
+                if reranking_model and retrival_method == 'full_text_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
+                        query=query,
+                        documents=documents,
+                        score_threshold=score_threshold,
+                        top_n=len(documents)
+                    ))
+                else:
+                    all_documents.extend(documents)

+ 0 - 0
api/core/rag/datasource/vdb/__init__.py


+ 10 - 0
api/core/rag/datasource/vdb/field.py

@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class Field(Enum):
+    CONTENT_KEY = "page_content"
+    METADATA_KEY = "metadata"
+    GROUP_KEY = "group_id"
+    VECTOR = "vector"
+    TEXT_KEY = "text"
+    PRIMARY_KEY = " id"

+ 0 - 0
api/core/rag/datasource/vdb/milvus/__init__.py


+ 214 - 0
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -0,0 +1,214 @@
+import logging
+from typing import Any, Optional
+from uuid import uuid4
+
+from pydantic import BaseModel, root_validator
+from pymilvus import MilvusClient, MilvusException, connections
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MilvusConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    secure: bool = False
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['host']:
+            raise ValueError("config MILVUS_HOST is required")
+        if not values['port']:
+            raise ValueError("config MILVUS_PORT is required")
+        if not values['user']:
+            raise ValueError("config MILVUS_USER is required")
+        if not values['password']:
+            raise ValueError("config MILVUS_PASSWORD is required")
+        return values
+
+    def to_milvus_params(self):
+        return {
+            'host': self.host,
+            'port': self.port,
+            'user': self.user,
+            'password': self.password,
+            'secure': self.secure
+        }
+
+
+class MilvusVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: MilvusConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = self._init_client(config)
+        self._consistency_level = 'Session'
+        self._fields = []
+
+    def get_type(self) -> str:
+        return 'milvus'
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        index_params = {
+            'metric_type': 'IP',
+            'index_type': "HNSW",
+            'params': {"M": 8, "efConstruction": 64}
+        }
+        metadatas = [d.metadata for d in texts]
+
+        # Grab the existing collection if it exists
+        from pymilvus import utility
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
+        if not utility.has_collection(self._collection_name, using=alias):
+            self.create_collection(embeddings, metadatas, index_params)
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        insert_dict_list = []
+        for i in range(len(documents)):
+            insert_dict = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],
+                Field.METADATA_KEY.value: documents[i].metadata
+            }
+            insert_dict_list.append(insert_dict)
+        # Total insert count
+        total_count = len(insert_dict_list)
+
+        pks: list[str] = []
+
+        for i in range(0, total_count, 1000):
+            batch_insert_list = insert_dict_list[i:i + 1000]
+            # Insert into the collection.
+            try:
+                ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
+                pks.extend(ids)
+            except MilvusException as e:
+                logger.error(
+                    "Failed to insert batch starting at entity: %s/%s", i, total_count
+                )
+                raise e
+        return pks
+
+    def delete_by_document_id(self, document_id: str):
+
+        ids = self.get_ids_by_metadata_field('document_id', document_id)
+        if ids:
+            self._client.delete(collection_name=self._collection_name, pks=ids)
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        result = self._client.query(collection_name=self._collection_name,
+                                    filter=f'metadata["{key}"] == "{value}"',
+                                    output_fields=["id"])
+        if result:
+            return [item["id"] for item in result]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        ids = self.get_ids_by_metadata_field(key, value)
+        if ids:
+            self._client.delete(collection_name=self._collection_name, pks=ids)
+
+    def delete_by_ids(self, doc_ids: list[str]) -> None:
+
+        self._client.delete(collection_name=self._collection_name, pks=doc_ids)
+
+    def delete(self) -> None:
+
+        from pymilvus import utility
+        utility.drop_collection(self._collection_name, None)
+
+    def text_exists(self, id: str) -> bool:
+
+        result = self._client.query(collection_name=self._collection_name,
+                                    filter=f'metadata["doc_id"] == "{id}"',
+                                    output_fields=["id"])
+
+        return len(result) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+
+        # Set search parameters.
+        results = self._client.search(collection_name=self._collection_name,
+                                      data=[query_vector],
+                                      limit=kwargs.get('top_k', 4),
+                                      output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+                                      )
+        # Organize results.
+        docs = []
+        for result in results[0]:
+            metadata = result['entity'].get(Field.METADATA_KEY.value)
+            metadata['score'] = result['distance']
+            score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
+            if result['distance'] > score_threshold:
+                doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
+                               metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # milvus/zilliz doesn't support bm25 search
+        return []
+
+    def create_collection(
+            self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ) -> str:
+        from pymilvus import CollectionSchema, DataType, FieldSchema
+        from pymilvus.orm.types import infer_dtype_bydata
+
+        # Determine embedding dim
+        dim = len(embeddings[0])
+        fields = []
+        if metadatas:
+            fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
+
+        # Create the text field
+        fields.append(
+            FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
+        )
+        # Create the primary key field
+        fields.append(
+            FieldSchema(
+                Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
+            )
+        )
+        # Create the vector field, supports binary or float vectors
+        fields.append(
+            FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
+        )
+
+        # Create the schema for the collection
+        schema = CollectionSchema(fields)
+
+        for x in schema.fields:
+            self._fields.append(x.name)
+        # Since primary field is auto-id, no need to track it
+        self._fields.remove(Field.PRIMARY_KEY.value)
+
+        # Create the collection
+        collection_name = self._collection_name
+        self._client.create_collection_with_schema(collection_name=collection_name,
+                                                   schema=schema, index_param=index_params,
+                                                   consistency_level=self._consistency_level)
+        return collection_name
+
+    def _init_client(self, config) -> MilvusClient:
+        if config.secure:
+            uri = "https://" + str(config.host) + ":" + str(config.port)
+        else:
+            uri = "http://" + str(config.host) + ":" + str(config.port)
+        client = MilvusClient(uri=uri, user=config.user, password=config.password)
+        return client

+ 0 - 0
api/core/rag/datasource/vdb/qdrant/__init__.py


+ 360 - 0
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -0,0 +1,360 @@
+import os
+import uuid
+from collections.abc import Generator, Iterable, Sequence
+from itertools import islice
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
+
+import qdrant_client
+from pydantic import BaseModel
+from qdrant_client.http import models as rest
+from qdrant_client.http.models import (
+    FilterSelector,
+    HnswConfigDiff,
+    PayloadSchemaType,
+    TextIndexParams,
+    TextIndexType,
+    TokenizerType,
+)
+from qdrant_client.local.qdrant_local import QdrantLocal
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+
+if TYPE_CHECKING:
+    from qdrant_client import grpc  # noqa
+    from qdrant_client.conversions import common_types
+    from qdrant_client.http import models as rest
+
+    DictFilter = dict[str, Union[str, int, bool, dict, list]]
+    MetadataFilter = Union[DictFilter, common_types.Filter]
+
+
+class QdrantConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    timeout: float = 20
+    root_path: Optional[str]
+
+    def to_qdrant_params(self):
+        if self.endpoint and self.endpoint.startswith('path:'):
+            path = self.endpoint.replace('path:', '')
+            if not os.path.isabs(path):
+                path = os.path.join(self.root_path, path)
+
+            return {
+                'path': path
+            }
+        else:
+            return {
+                'url': self.endpoint,
+                'api_key': self.api_key,
+                'timeout': self.timeout
+            }
+
+
+class QdrantVector(BaseVector):
+
+    def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
+        self._distance_func = distance_func.upper()
+        self._group_id = group_id
+
+    def get_type(self) -> str:
+        return 'qdrant'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self._collection_name}
+        }
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        if texts:
+            # get embedding vector size
+            vector_size = len(embeddings[0])
+            # get collection name
+            collection_name = self._collection_name
+            collection_name = collection_name or uuid.uuid4().hex
+            all_collection_name = []
+            collections_response = self._client.get_collections()
+            collection_list = collections_response.collections
+            for collection in collection_list:
+                all_collection_name.append(collection.name)
+            if collection_name not in all_collection_name:
+                # create collection
+                self.create_collection(collection_name, vector_size)
+
+            self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str, vector_size: int):
+        from qdrant_client.http import models as rest
+        vectors_config = rest.VectorParams(
+            size=vector_size,
+            distance=rest.Distance[self._distance_func],
+        )
+        hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                     max_indexing_threads=0, on_disk=False)
+        self._client.recreate_collection(
+            collection_name=collection_name,
+            vectors_config=vectors_config,
+            hnsw_config=hnsw_config,
+            timeout=int(self._client_config.timeout),
+        )
+
+        # create payload index
+        self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
+                                          field_schema=PayloadSchemaType.KEYWORD,
+                                          field_type=PayloadSchemaType.KEYWORD)
+        # creat full text index
+        text_index_params = TextIndexParams(
+            type=TextIndexType.TEXT,
+            tokenizer=TokenizerType.MULTILINGUAL,
+            min_token_len=2,
+            max_token_len=20,
+            lowercase=True
+        )
+        self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
+                                          field_schema=text_index_params)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        added_ids = []
+        for batch_ids, points in self._generate_rest_batches(
+                texts, embeddings, metadatas, uuids, 64, self._group_id
+        ):
+            self._client.upsert(
+                collection_name=self._collection_name, points=points
+            )
+            added_ids.extend(batch_ids)
+
+        return added_ids
+
+    def _generate_rest_batches(
+            self,
+            texts: Iterable[str],
+            embeddings: list[list[float]],
+            metadatas: Optional[list[dict]] = None,
+            ids: Optional[Sequence[str]] = None,
+            batch_size: int = 64,
+            group_id: Optional[str] = None,
+    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
+        from qdrant_client.http import models as rest
+        texts_iterator = iter(texts)
+        embeddings_iterator = iter(embeddings)
+        metadatas_iterator = iter(metadatas or [])
+        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
+        while batch_texts := list(islice(texts_iterator, batch_size)):
+            # Take the corresponding metadata and id for each text in a batch
+            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
+            batch_ids = list(islice(ids_iterator, batch_size))
+
+            # Generate the embeddings for all the texts in a batch
+            batch_embeddings = list(islice(embeddings_iterator, batch_size))
+
+            points = [
+                rest.PointStruct(
+                    id=point_id,
+                    vector=vector,
+                    payload=payload,
+                )
+                for point_id, vector, payload in zip(
+                    batch_ids,
+                    batch_embeddings,
+                    self._build_payloads(
+                        batch_texts,
+                        batch_metadatas,
+                        Field.CONTENT_KEY.value,
+                        Field.METADATA_KEY.value,
+                        group_id,
+                        Field.GROUP_KEY.value,
+                    ),
+                )
+            ]
+
+            yield batch_ids, points
+
+    @classmethod
+    def _build_payloads(
+            cls,
+            texts: Iterable[str],
+            metadatas: Optional[list[dict]],
+            content_payload_key: str,
+            metadata_payload_key: str,
+            group_id: str,
+            group_payload_key: str
+    ) -> list[dict]:
+        payloads = []
+        for i, text in enumerate(texts):
+            if text is None:
+                raise ValueError(
+                    "At least one of the texts is None. Please remove it before "
+                    "calling .from_texts or .add_texts on Qdrant instance."
+                )
+            metadata = metadatas[i] if metadatas is not None else None
+            payloads.append(
+                {
+                    content_payload_key: text,
+                    metadata_payload_key: metadata,
+                    group_payload_key: group_id
+                }
+            )
+
+        return payloads
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        from qdrant_client.http import models
+
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key=f"metadata.{key}",
+                    match=models.MatchValue(value=value),
+                ),
+            ],
+        )
+
+        self._reload_if_needed()
+
+        self._client.delete(
+            collection_name=self._collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def delete(self):
+        from qdrant_client.http import models
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+            ],
+        )
+        self._client.delete(
+            collection_name=self._collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+
+        from qdrant_client.http import models
+        for node_id in ids:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchValue(value=node_id),
+                    ),
+                ],
+            )
+            self._client.delete(
+                collection_name=self._collection_name,
+                points_selector=FilterSelector(
+                    filter=filter
+                ),
+            )
+
+    def text_exists(self, id: str) -> bool:
+        response = self._client.retrieve(
+            collection_name=self._collection_name,
+            ids=[id]
+        )
+
+        return len(response) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        from qdrant_client.http import models
+        filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+            ],
+        )
+        results = self._client.search(
+            collection_name=self._collection_name,
+            query_vector=query_vector,
+            query_filter=filter,
+            limit=kwargs.get("top_k", 4),
+            with_payload=True,
+            with_vectors=True,
+            score_threshold=kwargs.get("score_threshold", .0)
+        )
+        docs = []
+        for result in results:
+            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
+            # duplicate check score threshold
+            score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+            if result.score > score_threshold:
+                metadata['score'] = result.score
+                doc = Document(
+                    page_content=result.payload.get(Field.CONTENT_KEY.value),
+                    metadata=metadata,
+                )
+                docs.append(doc)
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """Return docs most similar by bm25.
+        Returns:
+            List of documents most similar to the query text and distance for each.
+        """
+        from qdrant_client.http import models
+        scroll_filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+                models.FieldCondition(
+                    key="page_content",
+                    match=models.MatchText(text=query),
+                )
+            ]
+        )
+        response = self._client.scroll(
+            collection_name=self._collection_name,
+            scroll_filter=scroll_filter,
+            limit=kwargs.get('top_k', 2),
+            with_payload=True,
+            with_vectors=True
+
+        )
+        results = response[0]
+        documents = []
+        for result in results:
+            if result:
+                documents.append(self._document_from_scored_point(
+                    result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
+                ))
+
+        return documents
+
+    def _reload_if_needed(self):
+        if isinstance(self._client, QdrantLocal):
+            self._client = cast(QdrantLocal, self._client)
+            self._client._load()
+
+    @classmethod
+    def _document_from_scored_point(
+            cls,
+            scored_point: Any,
+            content_payload_key: str,
+            metadata_payload_key: str,
+    ) -> Document:
+        return Document(
+            page_content=scored_point.payload.get(content_payload_key),
+            metadata=scored_point.payload.get(metadata_payload_key) or {},
+        )

+ 62 - 0
api/core/rag/datasource/vdb/vector_base.py

@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from core.rag.models.document import Document
+
+
+class BaseVector(ABC):
+
+    def __init__(self, collection_name: str):
+        self._collection_name = collection_name
+
+    @abstractmethod
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def text_exists(self, id: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_ids(self, ids: list[str]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_by_vector(
+            self,
+            query_vector: list[float],
+            **kwargs: Any
+    ) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search_by_full_text(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        raise NotImplementedError
+
+    def delete(self) -> None:
+        raise NotImplementedError
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def _get_uuids(self, texts: list[Document]) -> list[str]:
+        return [text.metadata['doc_id'] for text in texts]

+ 171 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -0,0 +1,171 @@
+from typing import Any, cast
+
+from flask import current_app
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
+
+
+class Vector:
+    def __init__(self, dataset: Dataset, attributes: list = None):
+        if attributes is None:
+            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self._dataset = dataset
+        self._embeddings = self._get_embeddings()
+        self._attributes = attributes
+        self._vector_processor = self._init_vector()
+
+    def _init_vector(self) -> BaseVector:
+        config = cast(dict, current_app.config)
+        vector_type = config.get('VECTOR_STORE')
+
+        if self._dataset.index_struct_dict:
+            vector_type = self._dataset.index_struct_dict['type']
+
+        if not vector_type:
+            raise ValueError("Vector store must be specified.")
+
+        if vector_type == "weaviate":
+            from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            return WeaviateVector(
+                collection_name=collection_name,
+                config=WeaviateConfig(
+                    endpoint=config.get('WEAVIATE_ENDPOINT'),
+                    api_key=config.get('WEAVIATE_API_KEY'),
+                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
+                ),
+                attributes=self._attributes
+            )
+        elif vector_type == "qdrant":
+            from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
+            if self._dataset.collection_binding_id:
+                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                    filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
+                    one_or_none()
+                if dataset_collection_binding:
+                    collection_name = dataset_collection_binding.collection_name
+                else:
+                    raise ValueError('Dataset Collection Bindings is not exist!')
+            else:
+                if self._dataset.index_struct_dict:
+                    class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+                    collection_name = class_prefix
+                else:
+                    dataset_id = self._dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+            return QdrantVector(
+                collection_name=collection_name,
+                group_id=self._dataset.id,
+                config=QdrantConfig(
+                    endpoint=config.get('QDRANT_URL'),
+                    api_key=config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path,
+                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
+                )
+            )
+        elif vector_type == "milvus":
+            from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            return MilvusVector(
+                collection_name=collection_name,
+                config=MilvusConfig(
+                    host=config.get('MILVUS_HOST'),
+                    port=config.get('MILVUS_PORT'),
+                    user=config.get('MILVUS_USER'),
+                    password=config.get('MILVUS_PASSWORD'),
+                    secure=config.get('MILVUS_SECURE'),
+                )
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def create(self, texts: list = None, **kwargs):
+        if texts:
+            embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
+            self._vector_processor.create(
+                texts=texts,
+                embeddings=embeddings,
+                **kwargs
+            )
+
+    def add_texts(self, documents: list[Document], **kwargs):
+        if kwargs.get('duplicate_check', False):
+            documents = self._filter_duplicate_texts(documents)
+        embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
+        self._vector_processor.add_texts(
+            documents=documents,
+            embeddings=embeddings,
+            **kwargs
+        )
+
+    def text_exists(self, id: str) -> bool:
+        return self._vector_processor.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._vector_processor.delete_by_ids(ids)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        self._vector_processor.delete_by_metadata_field(key, value)
+
+    def search_by_vector(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        query_vector = self._embeddings.embed_query(query)
+        return self._vector_processor.search_by_vector(query_vector, **kwargs)
+
+    def search_by_full_text(
+            self, query: str,
+            **kwargs: Any
+    ) -> list[Document]:
+        return self._vector_processor.search_by_full_text(query, **kwargs)
+
+    def delete(self) -> None:
+        self._vector_processor.delete()
+
+    def _get_embeddings(self) -> Embeddings:
+        model_manager = ModelManager()
+
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=self._dataset.tenant_id,
+            provider=self._dataset.embedding_model_provider,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=self._dataset.embedding_model
+
+        )
+        return CacheEmbedding(embedding_model)
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def __getattr__(self, name):
+        if self._vector_processor is not None:
+            method = getattr(self._vector_processor, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'vector_processor' object has no attribute '{name}'")

+ 0 - 0
api/core/rag/datasource/vdb/weaviate/__init__.py


+ 235 - 0
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -0,0 +1,235 @@
+import datetime
+from typing import Any, Optional
+
+import requests
+import weaviate
+from pydantic import BaseModel, root_validator
+
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from models.dataset import Dataset
+
+
+class WeaviateConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config WEAVIATE_ENDPOINT is required")
+        return values
+
+
+class WeaviateVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
+        super().__init__(collection_name)
+        self._client = self._init_client(config)
+        self._attributes = attributes
+
+    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
+        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+
+        weaviate.connect.connection.has_grpc = False
+
+        try:
+            client = weaviate.Client(
+                url=config.endpoint,
+                auth_client_secret=auth_config,
+                timeout_config=(5, 60),
+                startup_period=None
+            )
+        except requests.exceptions.ConnectionError:
+            raise ConnectionError("Vector database connection error")
+
+        client.batch.configure(
+            # `batch_size` takes an `int` value to enable auto-batching
+            # (`None` is used for manual batching)
+            batch_size=config.batch_size,
+            # dynamically update the `batch_size` based on import speed
+            dynamic=True,
+            # `timeout_retries` takes an `int` value to retry on time outs
+            timeout_retries=3,
+        )
+
+        return client
+
+    def get_type(self) -> str:
+        return 'weaviate'
+
+    def get_collection_name(self, dataset: Dataset) -> str:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self._collection_name}
+        }
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+
+        schema = self._default_schema(self._collection_name)
+
+        # check whether the index already exists
+        if not self._client.schema.contains(schema):
+            # create collection
+            self._client.schema.create_class(schema)
+        # create vector
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        ids = []
+
+        with self._client.batch as batch:
+            for i, text in enumerate(texts):
+                data_properties = {Field.TEXT_KEY.value: text}
+                if metadatas is not None:
+                    for key, val in metadatas[i].items():
+                        data_properties[key] = self._json_serializable(val)
+
+                batch.add_data_object(
+                    data_object=data_properties,
+                    class_name=self._collection_name,
+                    uuid=uuids[i],
+                    vector=embeddings[i] if embeddings else None,
+                )
+                ids.append(uuids[i])
+        return ids
+
+    def delete_by_metadata_field(self, key: str, value: str):
+
+        where_filter = {
+            "operator": "Equal",
+            "path": [key],
+            "valueText": value
+        }
+
+        self._client.batch.delete_objects(
+            class_name=self._collection_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def delete(self):
+        self._client.schema.delete_class(self._collection_name)
+
+    def text_exists(self, id: str) -> bool:
+        collection_name = self._collection_name
+        result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": id,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][collection_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._client.data_object.delete(
+            ids,
+            class_name=self._collection_name
+        )
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """Look up similar documents by embedding vector in Weaviate."""
+        collection_name = self._collection_name
+        properties = self._attributes
+        properties.append(Field.TEXT_KEY.value)
+        query_obj = self._client.query.get(collection_name, properties)
+
+        vector = {"vector": query_vector}
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        result = (
+            query_obj.with_near_vector(vector)
+            .with_limit(kwargs.get("top_k", 4))
+            .with_additional(["vector", "distance"])
+            .do()
+        )
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        docs_and_scores = []
+        for res in result["data"]["Get"][collection_name]:
+            text = res.pop(Field.TEXT_KEY.value)
+            score = 1 - res["_additional"]["distance"]
+            docs_and_scores.append((Document(page_content=text, metadata=res), score))
+
+        docs = []
+        for doc, score in docs_and_scores:
+            score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+            # check score threshold
+            if score > score_threshold:
+                doc.metadata['score'] = score
+                docs.append(doc)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """Return docs using BM25F.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        collection_name = self._collection_name
+        content: dict[str, Any] = {"concepts": [query]}
+        properties = self._attributes
+        properties.append(Field.TEXT_KEY.value)
+        if kwargs.get("search_distance"):
+            content["certainty"] = kwargs.get("search_distance")
+        query_obj = self._client.query.get(collection_name, properties)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        if kwargs.get("additional"):
+            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        properties = ['text']
+        result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+        docs = []
+        for res in result["data"]["Get"][collection_name]:
+            text = res.pop(Field.TEXT_KEY.value)
+            docs.append(Document(page_content=text, metadata=res))
+        return docs
+
+    def _default_schema(self, index_name: str) -> dict:
+        return {
+            "class": index_name,
+            "properties": [
+                {
+                    "name": "text",
+                    "dataType": ["text"],
+                }
+            ],
+        }
+
+    def _json_serializable(self, value: Any) -> Any:
+        if isinstance(value, datetime.datetime):
+            return value.isoformat()
+        return value

+ 166 - 0
api/core/rag/extractor/blod/blod.py

@@ -0,0 +1,166 @@
+"""Schema for Blobs and Blob Loaders.
+
+The goal is to facilitate decoupling of content loading from content parsing code.
+
+In addition, content loading code should provide a lazy loading interface by default.
+"""
+from __future__ import annotations
+
+import contextlib
+import mimetypes
+from abc import ABC, abstractmethod
+from collections.abc import Generator, Iterable, Mapping
+from io import BufferedReader, BytesIO
+from pathlib import PurePath
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, root_validator
+
+PathLike = Union[str, PurePath]
+
+
+class Blob(BaseModel):
+    """A blob is used to represent raw data by either reference or value.
+
+    Provides an interface to materialize the blob in different representations, and
+    help to decouple the development of data loaders from the downstream parsing of
+    the raw data.
+
+    Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
+    """
+
+    data: Union[bytes, str, None]  # Raw data
+    mimetype: Optional[str] = None  # Not to be confused with a file extension
+    encoding: str = "utf-8"  # Use utf-8 as default encoding, if decoding to string
+    # Location where the original content was found
+    # Represent location on the local file system
+    # Useful for situations where downstream code assumes it must work with file paths
+    # rather than in-memory content.
+    path: Optional[PathLike] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+        frozen = True
+
+    @property
+    def source(self) -> Optional[str]:
+        """The source location of the blob as string if known otherwise none."""
+        return str(self.path) if self.path else None
+
+    @root_validator(pre=True)
+    def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
+        """Verify that either data or path is provided."""
+        if "data" not in values and "path" not in values:
+            raise ValueError("Either data or path must be provided")
+        return values
+
+    def as_string(self) -> str:
+        """Read data as a string."""
+        if self.data is None and self.path:
+            with open(str(self.path), encoding=self.encoding) as f:
+                return f.read()
+        elif isinstance(self.data, bytes):
+            return self.data.decode(self.encoding)
+        elif isinstance(self.data, str):
+            return self.data
+        else:
+            raise ValueError(f"Unable to get string for blob {self}")
+
+    def as_bytes(self) -> bytes:
+        """Read data as bytes."""
+        if isinstance(self.data, bytes):
+            return self.data
+        elif isinstance(self.data, str):
+            return self.data.encode(self.encoding)
+        elif self.data is None and self.path:
+            with open(str(self.path), "rb") as f:
+                return f.read()
+        else:
+            raise ValueError(f"Unable to get bytes for blob {self}")
+
+    @contextlib.contextmanager
+    def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]:
+        """Read data as a byte stream."""
+        if isinstance(self.data, bytes):
+            yield BytesIO(self.data)
+        elif self.data is None and self.path:
+            with open(str(self.path), "rb") as f:
+                yield f
+        else:
+            raise NotImplementedError(f"Unable to convert blob {self}")
+
+    @classmethod
+    def from_path(
+        cls,
+        path: PathLike,
+        *,
+        encoding: str = "utf-8",
+        mime_type: Optional[str] = None,
+        guess_type: bool = True,
+    ) -> Blob:
+        """Load the blob from a path like object.
+
+        Args:
+            path: path like object to file to be read
+            encoding: Encoding to use if decoding the bytes into a string
+            mime_type: if provided, will be set as the mime-type of the data
+            guess_type: If True, the mimetype will be guessed from the file extension,
+                        if a mime-type was not provided
+
+        Returns:
+            Blob instance
+        """
+        if mime_type is None and guess_type:
+            _mimetype = mimetypes.guess_type(path)[0] if guess_type else None
+        else:
+            _mimetype = mime_type
+        # We do not load the data immediately, instead we treat the blob as a
+        # reference to the underlying data.
+        return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
+
+    @classmethod
+    def from_data(
+        cls,
+        data: Union[str, bytes],
+        *,
+        encoding: str = "utf-8",
+        mime_type: Optional[str] = None,
+        path: Optional[str] = None,
+    ) -> Blob:
+        """Initialize the blob from in-memory data.
+
+        Args:
+            data: the in-memory data associated with the blob
+            encoding: Encoding to use if decoding the bytes into a string
+            mime_type: if provided, will be set as the mime-type of the data
+            path: if provided, will be set as the source from which the data came
+
+        Returns:
+            Blob instance
+        """
+        return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
+
+    def __repr__(self) -> str:
+        """Define the blob representation."""
+        str_repr = f"Blob {id(self)}"
+        if self.source:
+            str_repr += f" {self.source}"
+        return str_repr
+
+
+class BlobLoader(ABC):
+    """Abstract interface for blob loaders implementation.
+
+    Implementer should be able to load raw content from a datasource system according
+    to some criteria and return the raw content lazily as a stream of blobs.
+    """
+
+    @abstractmethod
+    def yield_blobs(
+        self,
+    ) -> Iterable[Blob]:
+        """A lazy loader for raw data represented by LangChain's Blob object.
+
+        Returns:
+            A generator over blobs
+        """

+ 71 - 0
api/core/rag/extractor/csv_extractor.py

@@ -0,0 +1,71 @@
+"""Abstract interface for document loader implementations."""
+import csv
+from typing import Optional
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+
+
+class CSVExtractor(BaseExtractor):
+    """Load CSV files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
+            source_column: Optional[str] = None,
+            csv_args: Optional[dict] = None,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+        self.source_column = source_column
+        self.csv_args = csv_args or {}
+
+    def extract(self) -> list[Document]:
+        """Load data into document objects."""
+        try:
+            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
+                docs = self._read_from_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_filze_encodings(self._file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
+                            docs = self._read_from_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self._file_path}") from e
+
+        return docs
+
+    def _read_from_file(self, csvfile) -> list[Document]:
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
+            try:
+                source = (
+                    row[self.source_column]
+                    if self.source_column is not None
+                    else ''
+                )
+            except KeyError:
+                raise ValueError(
+                    f"Source column '{self.source_column}' not found in CSV file."
+                )
+            metadata = {"source": source, "row": i}
+            doc = Document(page_content=content, metadata=metadata)
+            docs.append(doc)
+
+        return docs

+ 6 - 0
api/core/rag/extractor/entity/datasource_type.py

@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class DatasourceType(Enum):
+    FILE = "upload_file"
+    NOTION = "notion_import"

+ 36 - 0
api/core/rag/extractor/entity/extract_setting.py

@@ -0,0 +1,36 @@
+from pydantic import BaseModel
+
+from models.dataset import Document
+from models.model import UploadFile
+
+
+class NotionInfo(BaseModel):
+    """
+    Notion import info.
+    """
+    notion_workspace_id: str
+    notion_obj_id: str
+    notion_page_type: str
+    document: Document = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)
+
+
+class ExtractSetting(BaseModel):
+    """
+    Model class for provider response.
+    """
+    datasource_type: str
+    upload_file: UploadFile = None
+    notion_info: NotionInfo = None
+    document_model: str = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data) -> None:
+        super().__init__(**data)

+ 14 - 9
api/core/data_loader/loader/excel.py → api/core/rag/extractor/excel_extractor.py

@@ -1,14 +1,14 @@
-import logging
+"""Abstract interface for document loader implementations."""
+from typing import Optional
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
 from openpyxl.reader.excel import load_workbook
 
-logger = logging.getLogger(__name__)
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 
-class ExcelLoader(BaseLoader):
-    """Load xlxs files.
+class ExcelExtractor(BaseExtractor):
+    """Load Excel files.
 
 
     Args:
@@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader):
     """
 
     def __init__(
-        self,
-        file_path: str
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False
     ):
         """Initialize with file path."""
         self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
+        """Load from file path."""
         data = []
         keys = []
         wb = load_workbook(filename=self._file_path, read_only=True)

+ 139 - 0
api/core/rag/extractor/extract_processor.py

@@ -0,0 +1,139 @@
+import tempfile
+from pathlib import Path
+from typing import Union
+
+import requests
+from flask import current_app
+
+from core.rag.extractor.csv_extractor import CSVExtractor
+from core.rag.extractor.entity.datasource_type import DatasourceType
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.excel_extractor import ExcelExtractor
+from core.rag.extractor.html_extractor import HtmlExtractor
+from core.rag.extractor.markdown_extractor import MarkdownExtractor
+from core.rag.extractor.notion_extractor import NotionExtractor
+from core.rag.extractor.pdf_extractor import PdfExtractor
+from core.rag.extractor.text_extractor import TextExtractor
+from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
+from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
+from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
+from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
+from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
+from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
+from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
+from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
+from core.rag.extractor.word_extractor import WordExtractor
+from core.rag.models.document import Document
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+
+
+class ExtractProcessor:
+    @classmethod
+    def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \
+            -> Union[list[Document], str]:
+        extract_setting = ExtractSetting(
+            datasource_type="upload_file",
+            upload_file=upload_file,
+            document_model='text_model'
+        )
+        if return_text:
+            delimiter = '\n'
+            return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)])
+        else:
+            return cls.extract(extract_setting, is_automatic)
+
+    @classmethod
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
+        response = requests.get(url, headers={
+            "User-Agent": USER_AGENT
+        })
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(url).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            with open(file_path, 'wb') as file:
+                file.write(response.content)
+            extract_setting = ExtractSetting(
+                datasource_type="upload_file",
+                document_model='text_model'
+            )
+            if return_text:
+                delimiter = '\n'
+                return delimiter.join([document.page_content for document in cls.extract(
+                    extract_setting=extract_setting, file_path=file_path)])
+            else:
+                return cls.extract(extract_setting=extract_setting, file_path=file_path)
+
+    @classmethod
+    def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False,
+                file_path: str = None) -> list[Document]:
+        if extract_setting.datasource_type == DatasourceType.FILE.value:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                if not file_path:
+                    upload_file: UploadFile = extract_setting.upload_file
+                    suffix = Path(upload_file.key).suffix
+                    file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+                    storage.download(upload_file.key, file_path)
+                input_file = Path(file_path)
+                file_extension = input_file.suffix.lower()
+                etl_type = current_app.config['ETL_TYPE']
+                unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
+                if etl_type == 'Unstructured':
+                    if file_extension == '.xlsx':
+                        extractor = ExcelExtractor(file_path)
+                    elif file_extension == '.pdf':
+                        extractor = PdfExtractor(file_path)
+                    elif file_extension in ['.md', '.markdown']:
+                        extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \
+                            else MarkdownExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension in ['.htm', '.html']:
+                        extractor = HtmlExtractor(file_path)
+                    elif file_extension in ['.docx']:
+                        extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.csv':
+                        extractor = CSVExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension == '.msg':
+                        extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.eml':
+                        extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.ppt':
+                        extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.pptx':
+                        extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url)
+                    elif file_extension == '.xml':
+                        extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url)
+                    else:
+                        # txt
+                        extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \
+                            else TextExtractor(file_path, autodetect_encoding=True)
+                else:
+                    if file_extension == '.xlsx':
+                        extractor = ExcelExtractor(file_path)
+                    elif file_extension == '.pdf':
+                        extractor = PdfExtractor(file_path)
+                    elif file_extension in ['.md', '.markdown']:
+                        extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
+                    elif file_extension in ['.htm', '.html']:
+                        extractor = HtmlExtractor(file_path)
+                    elif file_extension in ['.docx']:
+                        extractor = WordExtractor(file_path)
+                    elif file_extension == '.csv':
+                        extractor = CSVExtractor(file_path, autodetect_encoding=True)
+                    else:
+                        # txt
+                        extractor = TextExtractor(file_path, autodetect_encoding=True)
+                return extractor.extract()
+        elif extract_setting.datasource_type == DatasourceType.NOTION.value:
+            extractor = NotionExtractor(
+                notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
+                notion_obj_id=extract_setting.notion_info.notion_obj_id,
+                notion_page_type=extract_setting.notion_info.notion_page_type,
+                document_model=extract_setting.notion_info.document
+            )
+            return extractor.extract()
+        else:
+            raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")

+ 12 - 0
api/core/rag/extractor/extractor_base.py

@@ -0,0 +1,12 @@
+"""Abstract interface for document loader implementations."""
+from abc import ABC, abstractmethod
+
+
+class BaseExtractor(ABC):
+    """Interface for extract files.
+    """
+
+    @abstractmethod
+    def extract(self):
+        raise NotImplementedError
+

+ 46 - 0
api/core/rag/extractor/helpers.py

@@ -0,0 +1,46 @@
+"""Document loader helpers."""
+
+import concurrent.futures
+from typing import NamedTuple, Optional, cast
+
+
+class FileEncoding(NamedTuple):
+    """A file encoding as the NamedTuple."""
+
+    encoding: Optional[str]
+    """The encoding of the file."""
+    confidence: float
+    """The confidence of the encoding."""
+    language: Optional[str]
+    """The language of the file."""
+
+
+def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
+    """Try to detect the file encoding.
+
+    Returns a list of `FileEncoding` tuples with the detected encodings ordered
+    by confidence.
+
+    Args:
+        file_path: The path to the file to detect the encoding for.
+        timeout: The timeout in seconds for the encoding detection.
+    """
+    import chardet
+
+    def read_and_detect(file_path: str) -> list[dict]:
+        with open(file_path, "rb") as f:
+            rawdata = f.read()
+        return cast(list[dict], chardet.detect_all(rawdata))
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future = executor.submit(read_and_detect, file_path)
+        try:
+            encodings = future.result(timeout=timeout)
+        except concurrent.futures.TimeoutError:
+            raise TimeoutError(
+                f"Timeout reached while detecting encoding for {file_path}"
+            )
+
+    if all(encoding["encoding"] is None for encoding in encodings):
+        raise RuntimeError(f"Could not detect encoding for {file_path}")
+    return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]

+ 24 - 20
api/core/data_loader/loader/csv_loader.py → api/core/rag/extractor/html_extractor.py

@@ -1,51 +1,55 @@
-import csv
-import logging
+"""Abstract interface for document loader implementations."""
 from typing import Optional
 
-from langchain.document_loaders import CSVLoader as LCCSVLoader
-from langchain.document_loaders.helpers import detect_file_encodings
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
 
-logger = logging.getLogger(__name__)
 
+class HtmlExtractor(BaseExtractor):
+    """Load html files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
 
-class CSVLoader(LCCSVLoader):
     def __init__(
             self,
             file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False,
             source_column: Optional[str] = None,
             csv_args: Optional[dict] = None,
-            encoding: Optional[str] = None,
-            autodetect_encoding: bool = True,
     ):
-        self.file_path = file_path
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
         self.source_column = source_column
-        self.encoding = encoding
         self.csv_args = csv_args or {}
-        self.autodetect_encoding = autodetect_encoding
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         """Load data into document objects."""
         try:
-            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                 docs = self._read_from_file(csvfile)
         except UnicodeDecodeError as e:
-            if self.autodetect_encoding:
-                detected_encodings = detect_file_encodings(self.file_path)
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
-                    logger.debug("Trying encoding: ", encoding.encoding)
                     try:
-                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
+                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
                             docs = self._read_from_file(csvfile)
                         break
                     except UnicodeDecodeError:
                         continue
             else:
-                raise RuntimeError(f"Error loading {self.file_path}") from e
+                raise RuntimeError(f"Error loading {self._file_path}") from e
 
         return docs
 
-    def _read_from_file(self, csvfile):
+    def _read_from_file(self, csvfile) -> list[Document]:
         docs = []
         csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
         for i, row in enumerate(csv_reader):

+ 14 - 26
api/core/data_loader/loader/markdown.py → api/core/rag/extractor/markdown_extractor.py

@@ -1,39 +1,27 @@
-import logging
+"""Abstract interface for document loader implementations."""
 import re
 from typing import Optional, cast
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
 
-logger = logging.getLogger(__name__)
 
-
-class MarkdownLoader(BaseLoader):
-    """Load md files.
+class MarkdownExtractor(BaseExtractor):
+    """Load Markdown files.
 
 
     Args:
         file_path: Path to the file to load.
-
-        remove_hyperlinks: Whether to remove hyperlinks from the text.
-
-        remove_images: Whether to remove images from the text.
-
-        encoding: File encoding to use. If `None`, the file will be loaded
-        with the default system encoding.
-
-        autodetect_encoding: Whether to try to autodetect the file encoding
-            if the specified encoding fails.
     """
 
     def __init__(
-        self,
-        file_path: str,
-        remove_hyperlinks: bool = True,
-        remove_images: bool = True,
-        encoding: Optional[str] = None,
-        autodetect_encoding: bool = True,
+            self,
+            file_path: str,
+            remove_hyperlinks: bool = True,
+            remove_images: bool = True,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
     ):
         """Initialize with file path."""
         self._file_path = file_path
@@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader):
         self._encoding = encoding
         self._autodetect_encoding = autodetect_encoding
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
+        """Load from file path."""
         tups = self.parse_tups(self._file_path)
         documents = []
         for header, value in tups:
@@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader):
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(filepath)
                 for encoding in detected_encodings:
-                    logger.debug("Trying encoding: ", encoding.encoding)
                     try:
                         with open(filepath, encoding=encoding.encoding) as f:
                             content = f.read()

+ 23 - 37
api/core/data_loader/loader/notion.py → api/core/rag/extractor/notion_extractor.py

@@ -4,9 +4,10 @@ from typing import Any, Optional
 
 import requests
 from flask import current_app
-from langchain.document_loaders.base import BaseLoader
+from flask_login import current_user
 from langchain.schema import Document
 
+from core.rag.extractor.extractor_base import BaseExtractor
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
 from models.source import DataSourceBinding
@@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
 
 
-class NotionLoader(BaseLoader):
+class NotionExtractor(BaseExtractor):
+
     def __init__(
             self,
-            notion_access_token: str,
             notion_workspace_id: str,
             notion_obj_id: str,
             notion_page_type: str,
-            document_model: Optional[DocumentModel] = None
+            document_model: Optional[DocumentModel] = None,
+            notion_access_token: Optional[str] = None
     ):
+        self._notion_access_token = None
         self._document_model = document_model
         self._notion_workspace_id = notion_workspace_id
         self._notion_obj_id = notion_obj_id
         self._notion_page_type = notion_page_type
-        self._notion_access_token = notion_access_token
-
-        if not self._notion_access_token:
-            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
-            if integration_token is None:
-                raise ValueError(
-                    "Must specify `integration_token` or set environment "
-                    "variable `NOTION_INTEGRATION_TOKEN`."
-                )
-
-            self._notion_access_token = integration_token
-
-    @classmethod
-    def from_document(cls, document_model: DocumentModel):
-        data_source_info = document_model.data_source_info_dict
-        if not data_source_info or 'notion_page_id' not in data_source_info \
-                or 'notion_workspace_id' not in data_source_info:
-            raise ValueError("no notion page found")
-
-        notion_workspace_id = data_source_info['notion_workspace_id']
-        notion_obj_id = data_source_info['notion_page_id']
-        notion_page_type = data_source_info['type']
-        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
-
-        return cls(
-            notion_access_token=notion_access_token,
-            notion_workspace_id=notion_workspace_id,
-            notion_obj_id=notion_obj_id,
-            notion_page_type=notion_page_type,
-            document_model=document_model
-        )
-
-    def load(self) -> list[Document]:
+        if notion_access_token:
+            self._notion_access_token = notion_access_token
+        else:
+            self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
+                                                               self._notion_workspace_id)
+            if not self._notion_access_token:
+                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+                if integration_token is None:
+                    raise ValueError(
+                        "Must specify `integration_token` or set environment "
+                        "variable `NOTION_INTEGRATION_TOKEN`."
+                    )
+
+                self._notion_access_token = integration_token
+
+    def extract(self) -> list[Document]:
         self.update_last_edited_time(
             self._document_model
         )

+ 72 - 0
api/core/rag/extractor/pdf_extractor.py

@@ -0,0 +1,72 @@
+"""Abstract interface for document loader implementations."""
+from collections.abc import Iterator
+from typing import Optional
+
+from core.rag.extractor.blod.blod import Blob
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+from extensions.ext_storage import storage
+
+
+class PdfExtractor(BaseExtractor):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            file_cache_key: Optional[str] = None
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._file_cache_key = file_cache_key
+
+    def extract(self) -> list[Document]:
+        plaintext_file_key = ''
+        plaintext_file_exists = False
+        if self._file_cache_key:
+            try:
+                text = storage.load(self._file_cache_key).decode('utf-8')
+                plaintext_file_exists = True
+                return [Document(page_content=text)]
+            except FileNotFoundError:
+                pass
+        documents = list(self.load())
+        text_list = []
+        for document in documents:
+            text_list.append(document.page_content)
+        text = "\n\n".join(text_list)
+
+        # save plaintext file for caching
+        if not plaintext_file_exists and plaintext_file_key:
+            storage.save(plaintext_file_key, text.encode('utf-8'))
+
+        return documents
+
+    def load(
+            self,
+    ) -> Iterator[Document]:
+        """Lazy load given path as pages."""
+        blob = Blob.from_path(self._file_path)
+        yield from self.parse(blob)
+
+    def parse(self, blob: Blob) -> Iterator[Document]:
+        """Lazily parse the blob."""
+        import pypdfium2
+
+        with blob.as_bytes_io() as file_path:
+            pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
+            try:
+                for page_number, page in enumerate(pdf_reader):
+                    text_page = page.get_textpage()
+                    content = text_page.get_text_range()
+                    text_page.close()
+                    page.close()
+                    metadata = {"source": blob.source, "page": page_number}
+                    yield Document(page_content=content, metadata=metadata)
+            finally:
+                pdf_reader.close()

+ 50 - 0
api/core/rag/extractor/text_extractor.py

@@ -0,0 +1,50 @@
+"""Abstract interface for document loader implementations."""
+from typing import Optional
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
+from core.rag.models.document import Document
+
+
+class TextExtractor(BaseExtractor):
+    """Load text files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = False
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def extract(self) -> list[Document]:
+        """Load from file path."""
+        text = ""
+        try:
+            with open(self._file_path, encoding=self._encoding) as f:
+                text = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(self._file_path)
+                for encoding in detected_encodings:
+                    try:
+                        with open(self._file_path, encoding=encoding.encoding) as f:
+                            text = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self._file_path}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {self._file_path}") from e
+
+        metadata = {"source": self._file_path}
+        return [Document(page_content=text, metadata=metadata)]

+ 61 - 0
api/core/rag/extractor/unstructured/unstructured_doc_extractor.py

@@ -0,0 +1,61 @@
+import logging
+import os
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+
+logger = logging.getLogger(__name__)
+
+
+class UnstructuredWordExtractor(BaseExtractor):
+    """Loader that uses unstructured to load word documents.
+    """
+
+    def __init__(
+            self,
+            file_path: str,
+            api_url: str,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._api_url = api_url
+
+    def extract(self) -> list[Document]:
+        from unstructured.__version__ import __version__ as __unstructured_version__
+        from unstructured.file_utils.filetype import FileType, detect_filetype
+
+        unstructured_version = tuple(
+            [int(x) for x in __unstructured_version__.split(".")]
+        )
+        # check the file extension
+        try:
+            import magic  # noqa: F401
+
+            is_doc = detect_filetype(self._file_path) == FileType.DOC
+        except ImportError:
+            _, extension = os.path.splitext(str(self._file_path))
+            is_doc = extension == ".doc"
+
+        if is_doc and unstructured_version < (0, 4, 11):
+            raise ValueError(
+                f"You are on unstructured version {__unstructured_version__}. "
+                "Partitioning .doc files is only supported in unstructured>=0.4.11. "
+                "Please upgrade the unstructured package and try again."
+            )
+
+        if is_doc:
+            from unstructured.partition.doc import partition_doc
+
+            elements = partition_doc(filename=self._file_path)
+        else:
+            from unstructured.partition.docx import partition_docx
+
+            elements = partition_docx(filename=self._file_path)
+
+        from unstructured.chunking.title import chunk_by_title
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        documents = []
+        for chunk in chunks:
+            text = chunk.text.strip()
+            documents.append(Document(page_content=text))
+        return documents

+ 5 - 4
api/core/data_loader/loader/unstructured/unstructured_eml.py → api/core/rag/extractor/unstructured/unstructured_eml_extractor.py

@@ -2,13 +2,14 @@ import base64
 import logging
 
 from bs4 import BeautifulSoup
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredEmailLoader(BaseLoader):
+class UnstructuredEmailExtractor(BaseExtractor):
     """Load msg files.
     Args:
         file_path: Path to the file to load.
@@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.email import partition_email
         elements = partition_email(filename=self._file_path, api_url=self._api_url)
 

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_markdown.py → api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py

@@ -1,12 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredMarkdownLoader(BaseLoader):
+class UnstructuredMarkdownExtractor(BaseExtractor):
     """Load md files.
 
 
@@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.md import partition_md
 
         elements = partition_md(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_msg.py → api/core/rag/extractor/unstructured/unstructured_msg_extractor.py

@@ -1,12 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredMsgLoader(BaseLoader):
+class UnstructuredMsgExtractor(BaseExtractor):
     """Load msg files.
 
 
@@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.msg import partition_msg
 
         elements = partition_msg(filename=self._file_path, api_url=self._api_url)

+ 8 - 7
api/core/data_loader/loader/unstructured/unstructured_ppt.py → api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py

@@ -1,11 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
-class UnstructuredPPTLoader(BaseLoader):
+
+class UnstructuredPPTExtractor(BaseExtractor):
     """Load msg files.
 
 
@@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader):
     """
 
     def __init__(
-        self,
-        file_path: str,
-        api_url: str
+            self,
+            file_path: str,
+            api_url: str
     ):
         """Initialize with file path."""
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.ppt import partition_ppt
 
         elements = partition_ppt(filename=self._file_path, api_url=self._api_url)

+ 9 - 7
api/core/data_loader/loader/unstructured/unstructured_pptx.py → api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py

@@ -1,10 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
-class UnstructuredPPTXLoader(BaseLoader):
+
+
+class UnstructuredPPTXExtractor(BaseExtractor):
     """Load msg files.
 
 
@@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader):
     """
 
     def __init__(
-        self,
-        file_path: str,
-        api_url: str
+            self,
+            file_path: str,
+            api_url: str
     ):
         """Initialize with file path."""
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.pptx import partition_pptx
 
         elements = partition_pptx(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_text.py → api/core/rag/extractor/unstructured/unstructured_text_extractor.py

@@ -1,12 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredTextLoader(BaseLoader):
+class UnstructuredTextExtractor(BaseExtractor):
     """Load msg files.
 
 
@@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.text import partition_text
 
         elements = partition_text(filename=self._file_path, api_url=self._api_url)

+ 4 - 4
api/core/data_loader/loader/unstructured/unstructured_xml.py → api/core/rag/extractor/unstructured/unstructured_xml_extractor.py

@@ -1,12 +1,12 @@
 import logging
 
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
 
 logger = logging.getLogger(__name__)
 
 
-class UnstructuredXmlLoader(BaseLoader):
+class UnstructuredXmlExtractor(BaseExtractor):
     """Load msg files.
 
 
@@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url
 
-    def load(self) -> list[Document]:
+    def extract(self) -> list[Document]:
         from unstructured.partition.xml import partition_xml
 
         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)

+ 62 - 0
api/core/rag/extractor/word_extractor.py

@@ -0,0 +1,62 @@
+"""Abstract interface for document loader implementations."""
+import os
+import tempfile
+from urllib.parse import urlparse
+
+import requests
+
+from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.models.document import Document
+
+
+class WordExtractor(BaseExtractor):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(self, file_path: str):
+        """Initialize with file path."""
+        self.file_path = file_path
+        if "~" in self.file_path:
+            self.file_path = os.path.expanduser(self.file_path)
+
+        # If the file is a web path, download it to a temporary file, and use that
+        if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
+            r = requests.get(self.file_path)
+
+            if r.status_code != 200:
+                raise ValueError(
+                    "Check the url of your file; returned status code %s"
+                    % r.status_code
+                )
+
+            self.web_path = self.file_path
+            self.temp_file = tempfile.NamedTemporaryFile()
+            self.temp_file.write(r.content)
+            self.file_path = self.temp_file.name
+        elif not os.path.isfile(self.file_path):
+            raise ValueError("File path %s is not a valid file or url" % self.file_path)
+
+    def __del__(self) -> None:
+        if hasattr(self, "temp_file"):
+            self.temp_file.close()
+
+    def extract(self) -> list[Document]:
+        """Load given path as single page."""
+        import docx2txt
+
+        return [
+            Document(
+                page_content=docx2txt.process(self.file_path),
+                metadata={"source": self.file_path},
+            )
+        ]
+
+    @staticmethod
+    def _is_valid_url(url: str) -> bool:
+        """Check if the url is valid."""
+        parsed = urlparse(url)
+        return bool(parsed.netloc) and bool(parsed.scheme)

+ 0 - 0
api/core/rag/index_processor/__init__.py


+ 0 - 0
api/core/rag/index_processor/constant/__init__.py


+ 8 - 0
api/core/rag/index_processor/constant/index_type.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class IndexType(Enum):
+    PARAGRAPH_INDEX = "text_model"
+    QA_INDEX = "qa_model"
+    PARENT_CHILD_INDEX = "parent_child_index"
+    SUMMARY_INDEX = "summary_index"

+ 70 - 0
api/core/rag/index_processor/index_processor_base.py

@@ -0,0 +1,70 @@
+"""Abstract interface for document loader implementations."""
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from langchain.text_splitter import TextSplitter
+
+from core.model_manager import ModelInstance
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.models.document import Document
+from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
+from models.dataset import Dataset, DatasetProcessRule
+
+
+class BaseIndexProcessor(ABC):
+    """Interface for extract files.
+    """
+
+    @abstractmethod
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        raise NotImplementedError
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        raise NotImplementedError
+
+    @abstractmethod
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        raise NotImplementedError
+
+    def _get_splitter(self, processing_rule: dict,
+                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
+        """
+        Get the NodeParser object according to the processing rule.
+        """
+        if processing_rule['mode'] == "custom":
+            # The user-defined segmentation rule
+            rules = processing_rule['rules']
+            segmentation = rules["segmentation"]
+            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
+                raise ValueError("Custom segment length should be between 50 and 1000.")
+
+            separator = segmentation["separator"]
+            if separator:
+                separator = separator.replace('\\n', '\n')
+
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=segmentation["max_tokens"],
+                chunk_overlap=0,
+                fixed_separator=separator,
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
+            )
+        else:
+            # Automatic segmentation
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
+                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
+                chunk_overlap=0,
+                separators=["\n\n", "。", ".", " ", ""],
+                embedding_model_instance=embedding_model_instance
+            )
+
+        return character_splitter

+ 28 - 0
api/core/rag/index_processor/index_processor_factory.py

@@ -0,0 +1,28 @@
+"""Abstract interface for document loader implementations."""
+
+from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
+
+
+class IndexProcessorFactory:
+    """IndexProcessorInit.
+    """
+
+    def __init__(self, index_type: str):
+        self._index_type = index_type
+
+    def init_index_processor(self) -> BaseIndexProcessor:
+        """Init index processor."""
+
+        if not self._index_type:
+            raise ValueError("Index type must be specified.")
+
+        if self._index_type == IndexType.PARAGRAPH_INDEX.value:
+            return ParagraphIndexProcessor()
+        elif self._index_type == IndexType.QA_INDEX.value:
+
+            return QAIndexProcessor()
+        else:
+            raise ValueError(f"Index type {self._index_type} is not supported.")

+ 0 - 0
api/core/rag/index_processor/processor/__init__.py


+ 92 - 0
api/core/rag/index_processor/processor/paragraph_index_processor.py

@@ -0,0 +1,92 @@
+"""Paragraph index processor."""
+import uuid
+from typing import Optional
+
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+
+
+class ParagraphIndexProcessor(BaseIndexProcessor):
+
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=kwargs.get('process_rule_mode') == "automatic")
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        # Split the text documents into nodes.
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
+        all_documents = []
+        for document in documents:
+            # document clean
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+            # parse document to nodes
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    # delete Spliter character
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    else:
+                        page_content = page_content
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+        if with_keywords:
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            if node_ids:
+                vector.delete_by_ids(node_ids)
+            else:
+                vector.delete()
+        if with_keywords:
+            keyword = Keyword(dataset)
+            if node_ids:
+                keyword.delete_by_ids(node_ids)
+            else:
+                keyword.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs

+ 161 - 0
api/core/rag/index_processor/processor/qa_index_processor.py

@@ -0,0 +1,161 @@
+"""Paragraph index processor."""
+import logging
+import re
+import threading
+import uuid
+from typing import Optional
+
+import pandas as pd
+from flask import Flask, current_app
+from flask_login import current_user
+from werkzeug.datastructures import FileStorage
+
+from core.generator.llm_generator import LLMGenerator
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+
+
+class QAIndexProcessor(BaseIndexProcessor):
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=kwargs.get('process_rule_mode') == "automatic")
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=None)
+
+        # Split the text documents into nodes.
+        all_documents = []
+        all_qa_documents = []
+        for document in documents:
+            # document clean
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+
+            # parse document to nodes
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = hash
+                    # delete Spliter character
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    else:
+                        page_content = page_content
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        for i in range(0, len(all_documents), 10):
+            threads = []
+            sub_documents = all_documents[i:i + 10]
+            for doc in sub_documents:
+                document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'tenant_id': current_user.current_tenant.id,
+                    'document_node': doc,
+                    'all_qa_documents': all_qa_documents,
+                    'document_language': kwargs.get('document_language', 'English')})
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+        return all_qa_documents
+
+    def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
+
+        # check file type
+        if not file.filename.endswith('.csv'):
+            raise ValueError("Invalid file type. Only CSV files are allowed")
+
+        try:
+            # Skip the first row
+            df = pd.read_csv(file)
+            text_docs = []
+            for index, row in df.iterrows():
+                data = Document(page_content=row[0], metadata={'answer': row[1]})
+                text_docs.append(data)
+            if len(text_docs) == 0:
+                raise ValueError("The CSV file is empty.")
+
+        except Exception as e:
+            raise ValueError(str(e))
+        return text_docs
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        vector = Vector(dataset)
+        if node_ids:
+            vector.delete_by_ids(node_ids)
+        else:
+            vector.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict):
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
+
+    def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
+        format_documents = []
+        if document_node.page_content is None or not document_node.page_content.strip():
+            return
+        with flask_app.app_context():
+            try:
+                # qa model document
+                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
+                document_qa_list = self._format_split_text(response)
+                qa_documents = []
+                for result in document_qa_list:
+                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result['question'])
+                    qa_document.metadata['answer'] = result['answer']
+                    qa_document.metadata['doc_id'] = doc_id
+                    qa_document.metadata['doc_hash'] = hash
+                    qa_documents.append(qa_document)
+                format_documents.extend(qa_documents)
+            except Exception as e:
+                logging.exception(e)
+
+            all_qa_documents.extend(format_documents)
+
+    def _format_split_text(self, text):
+        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
+        matches = re.findall(regex, text, re.UNICODE)
+
+        return [
+            {
+                "question": q,
+                "answer": re.sub(r"\n\s*", "\n", a.strip())
+            }
+            for q, a in matches if q and a
+        ]

+ 0 - 0
api/core/rag/models/__init__.py


+ 16 - 0
api/core/rag/models/document.py

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class Document(BaseModel):
+    """Class for storing a piece of text and associated metadata."""
+
+    page_content: str
+
+    """Arbitrary metadata about the page content (e.g., source, relationships to other
+        documents, etc.).
+    """
+    metadata: Optional[dict] = Field(default_factory=dict)
+
+

+ 5 - 5
api/core/tool/web_reader_tool.py

@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
 from regex import regex
 
 from core.chain.llm_chain import LLMChain
-from core.data_loader import file_extractor
-from core.data_loader.file_extractor import FileExtractor
 from core.entities.application_entities import ModelConfigEntity
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -146,7 +146,7 @@ def get_url(url: str) -> str:
     headers = {
         "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
     }
-    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
 
     head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
 
@@ -158,8 +158,8 @@ def get_url(url: str) -> str:
     if main_content_type not in supported_content_types:
         return "Unsupported content-type [{}] of URL.".format(main_content_type)
 
-    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
-        return FileExtractor.load_from_url(url, return_text=True)
+    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+        return ExtractProcessor.load_from_url(url, return_text=True)
 
     response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
     a = extract_using_readabilipy(response.text)

+ 16 - 71
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -6,15 +6,12 @@ from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.embedding.cached_embedding import CacheEmbedding
-from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from services.retrieval_service import RetrievalService
 
 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool):
 
             if dataset.indexing_technique == "economy":
                 # use keyword table query
-                kw_table_index = KeywordTableIndex(
-                    dataset=dataset,
-                    config=KeywordTableConfig(
-                        max_keywords_per_chunk=5
-                    )
-                )
-
-                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=self.top_k
+                                                      )
                 if documents:
                     all_documents.extend(documents)
             else:
-
-                try:
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model
-                    )
-                except LLMBadRequestError:
-                    return []
-                except ProviderTokenNotInitError:
-                    return []
-
-                embeddings = CacheEmbedding(embedding_model)
-
-                documents = []
-                threads = []
                 if self.top_k > 0:
-                    # retrieval_model source with semantic
-                    if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
-                        'search_method'] == 'hybrid_search':
-                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                            'flask_app': current_app._get_current_object(),
-                            'dataset_id': str(dataset.id),
-                            'query': query,
-                            'top_k': self.top_k,
-                            'score_threshold': self.score_threshold,
-                            'reranking_model': None,
-                            'all_documents': documents,
-                            'search_method': 'hybrid_search',
-                            'embeddings': embeddings
-                        })
-                        threads.append(embedding_thread)
-                        embedding_thread.start()
-
-                    # retrieval_model source with full text
-                    if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
-                        'search_method'] == 'hybrid_search':
-                        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
-                                                                  kwargs={
-                                                                      'flask_app': current_app._get_current_object(),
-                                                                      'dataset_id': str(dataset.id),
-                                                                      'query': query,
-                                                                      'search_method': 'hybrid_search',
-                                                                      'embeddings': embeddings,
-                                                                      'score_threshold': retrieval_model[
-                                                                          'score_threshold'] if retrieval_model[
-                                                                          'score_threshold_enabled'] else None,
-                                                                      'top_k': self.top_k,
-                                                                      'reranking_model': retrieval_model[
-                                                                          'reranking_model'] if retrieval_model[
-                                                                          'reranking_enable'] else None,
-                                                                      'all_documents': documents
-                                                                  })
-                        threads.append(full_text_index_thread)
-                        full_text_index_thread.start()
-
-                    for thread in threads:
-                        thread.join()
+                    # retrieval source
+                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                          dataset_id=dataset.id,
+                                                          query=query,
+                                                          top_k=self.top_k,
+                                                          score_threshold=retrieval_model['score_threshold']
+                                                          if retrieval_model['score_threshold_enabled'] else None,
+                                                          reranking_model=retrieval_model['reranking_model']
+                                                          if retrieval_model['reranking_enable'] else None
+                                                          )
 
                     all_documents.extend(documents)

+ 17 - 95
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -1,20 +1,12 @@
-import threading
 from typing import Optional
 
-from flask import current_app
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Field
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.embedding.cached_embedding import CacheEmbedding
-from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
+from core.rag.datasource.retrieval_service import RetrievalService
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
-from services.retrieval_service import RetrievalService
 
 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool):
         retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
         if dataset.indexing_technique == "economy":
             # use keyword table query
-            kw_table_index = KeywordTableIndex(
-                dataset=dataset,
-                config=KeywordTableConfig(
-                    max_keywords_per_chunk=5
-                )
-            )
-
-            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+            documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                  dataset_id=dataset.id,
+                                                  query=query,
+                                                  top_k=self.top_k
+                                                  )
             return str("\n".join([document.page_content for document in documents]))
         else:
-            # get embedding model instance
-            try:
-                model_manager = ModelManager()
-                embedding_model = model_manager.get_model_instance(
-                    tenant_id=dataset.tenant_id,
-                    provider=dataset.embedding_model_provider,
-                    model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
-                )
-            except InvokeAuthorizationError:
-                return ''
-
-            embeddings = CacheEmbedding(embedding_model)
-
-            documents = []
-            threads = []
             if self.top_k > 0:
-                # retrieval source with semantic
-                if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
-                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                        'flask_app': current_app._get_current_object(),
-                        'dataset_id': str(dataset.id),
-                        'query': query,
-                        'top_k': self.top_k,
-                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
-                            'reranking_enable'] else None,
-                        'all_documents': documents,
-                        'search_method': retrieval_model['search_method'],
-                        'embeddings': embeddings
-                    })
-                    threads.append(embedding_thread)
-                    embedding_thread.start()
-
-                # retrieval_model source with full text
-                if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
-                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                        'flask_app': current_app._get_current_object(),
-                        'dataset_id': str(dataset.id),
-                        'query': query,
-                        'search_method': retrieval_model['search_method'],
-                        'embeddings': embeddings,
-                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        'top_k': self.top_k,
-                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
-                            'reranking_enable'] else None,
-                        'all_documents': documents
-                    })
-                    threads.append(full_text_index_thread)
-                    full_text_index_thread.start()
-
-                for thread in threads:
-                    thread.join()
-
-                # hybrid search: rerank after all documents have been searched
-                if retrieval_model['search_method'] == 'hybrid_search':
-                    # get rerank model instance
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=retrieval_model['reranking_model']['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=retrieval_model['reranking_model']['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return ''
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    documents = rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=retrieval_model['score_threshold'] if retrieval_model[
-                            'score_threshold_enabled'] else None,
-                        top_n=self.top_k
-                    )
+                # retrieval source
+                documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                      dataset_id=dataset.id,
+                                                      query=query,
+                                                      top_k=self.top_k,
+                                                      score_threshold=retrieval_model['score_threshold']
+                                                      if retrieval_model['score_threshold_enabled'] else None,
+                                                      reranking_model=retrieval_model['reranking_model']
+                                                      if retrieval_model['reranking_enable'] else None
+                                                      )
             else:
                 documents = []
 
@@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool):
             return str("\n".join(document_context_list))
 
     async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
+        raise NotImplementedError()

+ 5 - 5
api/core/tools/utils/web_reader_tool.py

@@ -21,9 +21,9 @@ from pydantic import BaseModel, Field
 from regex import regex
 
 from core.chain.llm_chain import LLMChain
-from core.data_loader import file_extractor
-from core.data_loader.file_extractor import FileExtractor
 from core.entities.application_entities import ModelConfigEntity
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor
 
 FULL_TEMPLATE = """
 TITLE: {title}
@@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str:
     if user_agent:
         headers["User-Agent"] = user_agent
     
-    supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
 
     head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
 
@@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str:
     if main_content_type not in supported_content_types:
         return "Unsupported content-type [{}] of URL.".format(main_content_type)
 
-    if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES:
-        return FileExtractor.load_from_url(url, return_text=True)
+    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+        return ExtractProcessor.load_from_url(url, return_text=True)
 
     response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
     a = extract_using_readabilipy(response.text)

+ 0 - 56
api/core/vector_store/milvus_vector_store.py

@@ -1,56 +0,0 @@
-from core.vector_store.vector.milvus import Milvus
-
-
-class MilvusVectorStore(Milvus):
-    def del_texts(self, where_filter: dict):
-        if not where_filter:
-            raise ValueError('where_filter must not be empty')
-
-        self.col.delete(where_filter.get('filter'))
-
-    def del_text(self, uuid: str) -> None:
-        expr = f"id == {uuid}"
-        self.col.delete(expr)
-
-    def text_exists(self, uuid: str) -> bool:
-        result = self.col.query(
-            expr=f'metadata["doc_id"] == "{uuid}"',
-            output_fields=["id"]
-        )
-
-        return len(result) > 0
-
-    def get_ids_by_document_id(self, document_id: str):
-        result = self.col.query(
-            expr=f'metadata["document_id"] == "{document_id}"',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def get_ids_by_metadata_field(self, key: str, value: str):
-        result = self.col.query(
-            expr=f'metadata["{key}"] == "{value}"',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def get_ids_by_doc_ids(self, doc_ids: list):
-        result = self.col.query(
-            expr=f'metadata["doc_id"] in {doc_ids}',
-            output_fields=["id"]
-        )
-        if result:
-            return [item["id"] for item in result]
-        else:
-            return None
-
-    def delete(self):
-        from pymilvus import utility
-        utility.drop_collection(self.collection_name, None, self.alias)
-

+ 0 - 76
api/core/vector_store/qdrant_vector_store.py

@@ -1,76 +0,0 @@
-from typing import Any, cast
-
-from langchain.schema import Document
-from qdrant_client.http.models import Filter, FilterSelector, PointIdsList
-from qdrant_client.local.qdrant_local import QdrantLocal
-
-from core.vector_store.vector.qdrant import Qdrant
-
-
-class QdrantVectorStore(Qdrant):
-    def del_texts(self, filter: Filter):
-        if not filter:
-            raise ValueError('filter must not be empty')
-
-        self._reload_if_needed()
-
-        self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=FilterSelector(
-                filter=filter
-            ),
-        )
-
-    def del_text(self, uuid: str) -> None:
-        self._reload_if_needed()
-
-        self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=PointIdsList(
-                points=[uuid],
-            ),
-        )
-
-    def text_exists(self, uuid: str) -> bool:
-        self._reload_if_needed()
-
-        response = self.client.retrieve(
-            collection_name=self.collection_name,
-            ids=[uuid]
-        )
-
-        return len(response) > 0
-
-    def delete(self):
-        self._reload_if_needed()
-
-        self.client.delete_collection(collection_name=self.collection_name)
-
-    def delete_group(self):
-        self._reload_if_needed()
-
-        self.client.delete_collection(collection_name=self.collection_name)
-
-    @classmethod
-    def _document_from_scored_point(
-            cls,
-            scored_point: Any,
-            content_payload_key: str,
-            metadata_payload_key: str,
-    ) -> Document:
-        if scored_point.payload.get('doc_id'):
-            return Document(
-                page_content=scored_point.payload.get(content_payload_key),
-                metadata={'doc_id': scored_point.id}
-            )
-
-        return Document(
-            page_content=scored_point.payload.get(content_payload_key),
-            metadata=scored_point.payload.get(metadata_payload_key) or {},
-        )
-
-    def _reload_if_needed(self):
-        if isinstance(self.client, QdrantLocal):
-            self.client = cast(QdrantLocal, self.client)
-            self.client._load()
-

+ 0 - 852
api/core/vector_store/vector/milvus.py

@@ -1,852 +0,0 @@
-"""Wrapper around the Milvus vector database."""
-from __future__ import annotations
-
-import logging
-from collections.abc import Iterable, Sequence
-from typing import Any, Optional, Union
-from uuid import uuid4
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.vectorstores.base import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_MILVUS_CONNECTION = {
-    "host": "localhost",
-    "port": "19530",
-    "user": "",
-    "password": "",
-    "secure": False,
-}
-
-
-class Milvus(VectorStore):
-    """Initialize wrapper around the milvus vector database.
-
-    In order to use this you need to have `pymilvus` installed and a
-    running Milvus
-
-    See the following documentation for how to run a Milvus instance:
-    https://milvus.io/docs/install_standalone-docker.md
-
-    If looking for a hosted Milvus, take a look at this documentation:
-    https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
-    this project,
-
-    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
-
-    Args:
-        embedding_function (Embeddings): Function used to embed the text.
-        collection_name (str): Which Milvus collection to use. Defaults to
-            "LangChainCollection".
-        connection_args (Optional[dict[str, any]]): The connection args used for
-            this class comes in the form of a dict.
-        consistency_level (str): The consistency level to use for a collection.
-            Defaults to "Session".
-        index_params (Optional[dict]): Which index params to use. Defaults to
-            HNSW/AUTOINDEX depending on service.
-        search_params (Optional[dict]): Which search params to use. Defaults to
-            default of index.
-        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
-            to False.
-
-    The connection args used for this class comes in the form of a dict,
-    here are a few of the options:
-        address (str): The actual address of Milvus
-            instance. Example address: "localhost:19530"
-        uri (str): The uri of Milvus instance. Example uri:
-            "http://randomwebsite:19530",
-            "tcp:foobarsite:19530",
-            "https://ok.s3.south.com:19530".
-        host (str): The host of Milvus instance. Default at "localhost",
-            PyMilvus will fill in the default host if only port is provided.
-        port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
-            will fill in the default port if only host is provided.
-        user (str): Use which user to connect to Milvus instance. If user and
-            password are provided, we will add related header in every RPC call.
-        password (str): Required when user is provided. The password
-            corresponding to the user.
-        secure (bool): Default is false. If set to true, tls will be enabled.
-        client_key_path (str): If use tls two-way authentication, need to
-            write the client.key path.
-        client_pem_path (str): If use tls two-way authentication, need to
-            write the client.pem path.
-        ca_pem_path (str): If use tls two-way authentication, need to write
-            the ca.pem path.
-        server_pem_path (str): If use tls one-way authentication, need to
-            write the server.pem path.
-        server_name (str): If use tls, need to write the common name.
-
-    Example:
-        .. code-block:: python
-
-        from langchain import Milvus
-        from langchain.embeddings import OpenAIEmbeddings
-
-        embedding = OpenAIEmbeddings()
-        # Connect to a milvus instance on localhost
-        milvus_store = Milvus(
-            embedding_function = Embeddings,
-            collection_name = "LangChainCollection",
-            drop_old = True,
-        )
-
-    Raises:
-        ValueError: If the pymilvus python package is not installed.
-    """
-
-    def __init__(
-        self,
-        embedding_function: Embeddings,
-        collection_name: str = "LangChainCollection",
-        connection_args: Optional[dict[str, Any]] = None,
-        consistency_level: str = "Session",
-        index_params: Optional[dict] = None,
-        search_params: Optional[dict] = None,
-        drop_old: Optional[bool] = False,
-    ):
-        """Initialize the Milvus vector store."""
-        try:
-            from pymilvus import Collection, utility
-        except ImportError:
-            raise ValueError(
-                "Could not import pymilvus python package. "
-                "Please install it with `pip install pymilvus`."
-            )
-
-        # Default search params when one is not provided.
-        self.default_search_params = {
-            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
-            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
-            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
-            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
-            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
-            "AUTOINDEX": {"metric_type": "L2", "params": {}},
-        }
-
-        self.embedding_func = embedding_function
-        self.collection_name = collection_name
-        self.index_params = index_params
-        self.search_params = search_params
-        self.consistency_level = consistency_level
-
-        # In order for a collection to be compatible, pk needs to be auto'id and int
-        self._primary_field = "id"
-        # In order for compatibility, the text field will need to be called "text"
-        self._text_field = "page_content"
-        # In order for compatibility, the vector field needs to be called "vector"
-        self._vector_field = "vectors"
-        # In order for compatibility, the metadata field will need to be called "metadata"
-        self._metadata_field = "metadata"
-        self.fields: list[str] = []
-        # Create the connection to the server
-        if connection_args is None:
-            connection_args = DEFAULT_MILVUS_CONNECTION
-        self.alias = self._create_connection_alias(connection_args)
-        self.col: Optional[Collection] = None
-
-        # Grab the existing collection if it exists
-        if utility.has_collection(self.collection_name, using=self.alias):
-            self.col = Collection(
-                self.collection_name,
-                using=self.alias,
-            )
-        # If need to drop old, drop it
-        if drop_old and isinstance(self.col, Collection):
-            self.col.drop()
-            self.col = None
-
-        # Initialize the vector store
-        self._init()
-
-    @property
-    def embeddings(self) -> Embeddings:
-        return self.embedding_func
-
-    def _create_connection_alias(self, connection_args: dict) -> str:
-        """Create the connection to the Milvus server."""
-        from pymilvus import MilvusException, connections
-
-        # Grab the connection arguments that are used for checking existing connection
-        host: str = connection_args.get("host", None)
-        port: Union[str, int] = connection_args.get("port", None)
-        address: str = connection_args.get("address", None)
-        uri: str = connection_args.get("uri", None)
-        user = connection_args.get("user", None)
-
-        # Order of use is host/port, uri, address
-        if host is not None and port is not None:
-            given_address = str(host) + ":" + str(port)
-        elif uri is not None:
-            given_address = uri.split("https://")[1]
-        elif address is not None:
-            given_address = address
-        else:
-            given_address = None
-            logger.debug("Missing standard address type for reuse atttempt")
-
-        # User defaults to empty string when getting connection info
-        if user is not None:
-            tmp_user = user
-        else:
-            tmp_user = ""
-
-        # If a valid address was given, then check if a connection exists
-        if given_address is not None:
-            for con in connections.list_connections():
-                addr = connections.get_connection_addr(con[0])
-                if (
-                    con[1]
-                    and ("address" in addr)
-                    and (addr["address"] == given_address)
-                    and ("user" in addr)
-                    and (addr["user"] == tmp_user)
-                ):
-                    logger.debug("Using previous connection: %s", con[0])
-                    return con[0]
-
-        # Generate a new connection if one doesn't exist
-        alias = uuid4().hex
-        try:
-            connections.connect(alias=alias, **connection_args)
-            logger.debug("Created new connection using: %s", alias)
-            return alias
-        except MilvusException as e:
-            logger.error("Failed to create new connection using: %s", alias)
-            raise e
-
-    def _init(
-        self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
-    ) -> None:
-        if embeddings is not None:
-            self._create_collection(embeddings, metadatas)
-        self._extract_fields()
-        self._create_index()
-        self._create_search_params()
-        self._load()
-
-    def _create_collection(
-        self, embeddings: list, metadatas: Optional[list[dict]] = None
-    ) -> None:
-        from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException
-        from pymilvus.orm.types import infer_dtype_bydata
-
-        # Determine embedding dim
-        dim = len(embeddings[0])
-        fields = []
-        # Determine metadata schema
-        # if metadatas:
-        #     # Create FieldSchema for each entry in metadata.
-        #     for key, value in metadatas[0].items():
-        #         # Infer the corresponding datatype of the metadata
-        #         dtype = infer_dtype_bydata(value)
-        #         # Datatype isn't compatible
-        #         if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
-        #             logger.error(
-        #                 "Failure to create collection, unrecognized dtype for key: %s",
-        #                 key,
-        #             )
-        #             raise ValueError(f"Unrecognized datatype for {key}.")
-        #         # Dataype is a string/varchar equivalent
-        #         elif dtype == DataType.VARCHAR:
-        #             fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
-        #         else:
-        #             fields.append(FieldSchema(key, dtype))
-        if metadatas:
-            fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
-
-        # Create the text field
-        fields.append(
-            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
-        )
-        # Create the primary key field
-        fields.append(
-            FieldSchema(
-                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
-            )
-        )
-        # Create the vector field, supports binary or float vectors
-        fields.append(
-            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
-        )
-
-        # Create the schema for the collection
-        schema = CollectionSchema(fields)
-
-        # Create the collection
-        try:
-            self.col = Collection(
-                name=self.collection_name,
-                schema=schema,
-                consistency_level=self.consistency_level,
-                using=self.alias,
-            )
-        except MilvusException as e:
-            logger.error(
-                "Failed to create collection: %s error: %s", self.collection_name, e
-            )
-            raise e
-
-    def _extract_fields(self) -> None:
-        """Grab the existing fields from the Collection"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection):
-            schema = self.col.schema
-            for x in schema.fields:
-                self.fields.append(x.name)
-            # Since primary field is auto-id, no need to track it
-            self.fields.remove(self._primary_field)
-
-    def _get_index(self) -> Optional[dict[str, Any]]:
-        """Return the vector index information if it exists"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection):
-            for x in self.col.indexes:
-                if x.field_name == self._vector_field:
-                    return x.to_dict()
-        return None
-
-    def _create_index(self) -> None:
-        """Create a index on the collection"""
-        from pymilvus import Collection, MilvusException
-
-        if isinstance(self.col, Collection) and self._get_index() is None:
-            try:
-                # If no index params, use a default HNSW based one
-                if self.index_params is None:
-                    self.index_params = {
-                        "metric_type": "IP",
-                        "index_type": "HNSW",
-                        "params": {"M": 8, "efConstruction": 64},
-                    }
-
-                try:
-                    self.col.create_index(
-                        self._vector_field,
-                        index_params=self.index_params,
-                        using=self.alias,
-                    )
-
-                # If default did not work, most likely on Zilliz Cloud
-                except MilvusException:
-                    # Use AUTOINDEX based index
-                    self.index_params = {
-                        "metric_type": "L2",
-                        "index_type": "AUTOINDEX",
-                        "params": {},
-                    }
-                    self.col.create_index(
-                        self._vector_field,
-                        index_params=self.index_params,
-                        using=self.alias,
-                    )
-                logger.debug(
-                    "Successfully created an index on collection: %s",
-                    self.collection_name,
-                )
-
-            except MilvusException as e:
-                logger.error(
-                    "Failed to create an index on collection: %s", self.collection_name
-                )
-                raise e
-
-    def _create_search_params(self) -> None:
-        """Generate search params based on the current index type"""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection) and self.search_params is None:
-            index = self._get_index()
-            if index is not None:
-                index_type: str = index["index_param"]["index_type"]
-                metric_type: str = index["index_param"]["metric_type"]
-                self.search_params = self.default_search_params[index_type]
-                self.search_params["metric_type"] = metric_type
-
-    def _load(self) -> None:
-        """Load the collection if available."""
-        from pymilvus import Collection
-
-        if isinstance(self.col, Collection) and self._get_index() is not None:
-            self.col.load()
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        timeout: Optional[int] = None,
-        batch_size: int = 1000,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Insert text data into Milvus.
-
-        Inserting data when the collection has not be made yet will result
-        in creating a new Collection. The data of the first entity decides
-        the schema of the new collection, the dim is extracted from the first
-        embedding and the columns are decided by the first metadata dict.
-        Metada keys will need to be present for all inserted values. At
-        the moment there is no None equivalent in Milvus.
-
-        Args:
-            texts (Iterable[str]): The texts to embed, it is assumed
-                that they all fit in memory.
-            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
-                the texts. Defaults to None.
-            timeout (Optional[int]): Timeout for each batch insert. Defaults
-                to None.
-            batch_size (int, optional): Batch size to use for insertion.
-                Defaults to 1000.
-
-        Raises:
-            MilvusException: Failure to add texts
-
-        Returns:
-            List[str]: The resulting keys for each inserted element.
-        """
-        from pymilvus import Collection, MilvusException
-
-        texts = list(texts)
-
-        try:
-            embeddings = self.embedding_func.embed_documents(texts)
-        except NotImplementedError:
-            embeddings = [self.embedding_func.embed_query(x) for x in texts]
-
-        if len(embeddings) == 0:
-            logger.debug("Nothing to insert, skipping.")
-            return []
-
-        # If the collection hasn't been initialized yet, perform all steps to do so
-        if not isinstance(self.col, Collection):
-            self._init(embeddings, metadatas)
-
-        # Dict to hold all insert columns
-        insert_dict: dict[str, list] = {
-            self._text_field: texts,
-            self._vector_field: embeddings,
-        }
-
-        # Collect the metadata into the insert dict.
-        # if metadatas is not None:
-        #     for d in metadatas:
-        #         for key, value in d.items():
-        #             if key in self.fields:
-        #                 insert_dict.setdefault(key, []).append(value)
-        if metadatas is not None:
-            for d in metadatas:
-                insert_dict.setdefault(self._metadata_field, []).append(d)
-
-        # Total insert count
-        vectors: list = insert_dict[self._vector_field]
-        total_count = len(vectors)
-
-        pks: list[str] = []
-
-        assert isinstance(self.col, Collection)
-        for i in range(0, total_count, batch_size):
-            # Grab end index
-            end = min(i + batch_size, total_count)
-            # Convert dict to list of lists batch for insertion
-            insert_list = [insert_dict[x][i:end] for x in self.fields]
-            # Insert into the collection.
-            try:
-                res: Collection
-                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
-                pks.extend(res.primary_keys)
-            except MilvusException as e:
-                logger.error(
-                    "Failed to insert batch starting at entity: %s/%s", i, total_count
-                )
-                raise e
-        return pks
-
-    def similarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a similarity search against the query string.
-
-        Args:
-            query (str): The text to search.
-            k (int, optional): How many results to return. Defaults to 4.
-            param (dict, optional): The search params for the index type.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-        res = self.similarity_search_with_score(
-            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return [doc for doc, _ in res]
-
-    def similarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a similarity search against the query string.
-
-        Args:
-            embedding (List[float]): The embedding vector to search.
-            k (int, optional): How many results to return. Defaults to 4.
-            param (dict, optional): The search params for the index type.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-        res = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return [doc for doc, _ in res]
-
-    def similarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Perform a search on a query string and return results with score.
-
-        For more information about the search parameters, take a look at the pymilvus
-        documentation found here:
-        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
-
-        Args:
-            query (str): The text being searched.
-            k (int, optional): The amount of results to return. Defaults to 4.
-            param (dict): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[float], List[Tuple[Document, any, any]]:
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        # Embed the query text.
-        embedding = self.embedding_func.embed_query(query)
-
-        res = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
-        )
-        return res
-
-    def _similarity_search_with_relevance_scores(
-        self,
-        query: str,
-        k: int = 4,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs and relevance scores in the range [0, 1].
-
-        0 is dissimilar, 1 is most similar.
-
-        Args:
-            query: input text
-            k: Number of Documents to return. Defaults to 4.
-            **kwargs: kwargs to be passed to similarity search. Should include:
-                score_threshold: Optional, a floating point value between 0 to 1 to
-                    filter the resulting set of retrieved docs
-
-        Returns:
-            List of Tuples of (doc, similarity_score)
-        """
-        return self.similarity_search_with_score(query, k, **kwargs)
-
-    def similarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Perform a search on a query string and return results with score.
-
-        For more information about the search parameters, take a look at the pymilvus
-        documentation found here:
-        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
-
-        Args:
-            embedding (List[float]): The embedding vector being searched.
-            k (int, optional): The amount of results to return. Defaults to 4.
-            param (dict): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Tuple[Document, float]]: Result doc and score.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        if param is None:
-            param = self.search_params
-
-        # Determine result metadata fields.
-        output_fields = self.fields[:]
-        output_fields.remove(self._vector_field)
-
-        # Perform the search.
-        res = self.col.search(
-            data=[embedding],
-            anns_field=self._vector_field,
-            param=param,
-            limit=k,
-            expr=expr,
-            output_fields=output_fields,
-            timeout=timeout,
-            **kwargs,
-        )
-        # Organize results.
-        ret = []
-        for result in res[0]:
-            meta = {x: result.entity.get(x) for x in output_fields}
-            doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
-            pair = (doc, result.score)
-            ret.append(pair)
-
-        return ret
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a search and return results that are reordered by MMR.
-
-        Args:
-            query (str): The text being searched.
-            k (int, optional): How many results to give. Defaults to 4.
-            fetch_k (int, optional): Total results to select k from.
-                Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5
-            param (dict, optional): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        embedding = self.embedding_func.embed_query(query)
-
-        return self.max_marginal_relevance_search_by_vector(
-            embedding=embedding,
-            k=k,
-            fetch_k=fetch_k,
-            lambda_mult=lambda_mult,
-            param=param,
-            expr=expr,
-            timeout=timeout,
-            **kwargs,
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        param: Optional[dict] = None,
-        expr: Optional[str] = None,
-        timeout: Optional[int] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Perform a search and return results that are reordered by MMR.
-
-        Args:
-            embedding (str): The embedding vector being searched.
-            k (int, optional): How many results to give. Defaults to 4.
-            fetch_k (int, optional): Total results to select k from.
-                Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5
-            param (dict, optional): The search params for the specified index.
-                Defaults to None.
-            expr (str, optional): Filtering expression. Defaults to None.
-            timeout (int, optional): How long to wait before timeout error.
-                Defaults to None.
-            kwargs: Collection.search() keyword arguments.
-
-        Returns:
-            List[Document]: Document results for search.
-        """
-        if self.col is None:
-            logger.debug("No existing collection to search.")
-            return []
-
-        if param is None:
-            param = self.search_params
-
-        # Determine result metadata fields.
-        output_fields = self.fields[:]
-        output_fields.remove(self._vector_field)
-
-        # Perform the search.
-        res = self.col.search(
-            data=[embedding],
-            anns_field=self._vector_field,
-            param=param,
-            limit=fetch_k,
-            expr=expr,
-            output_fields=output_fields,
-            timeout=timeout,
-            **kwargs,
-        )
-        # Organize results.
-        ids = []
-        documents = []
-        scores = []
-        for result in res[0]:
-            meta = {x: result.entity.get(x) for x in output_fields}
-            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
-            documents.append(doc)
-            scores.append(result.score)
-            ids.append(result.id)
-
-        vectors = self.col.query(
-            expr=f"{self._primary_field} in {ids}",
-            output_fields=[self._primary_field, self._vector_field],
-            timeout=timeout,
-        )
-        # Reorganize the results from query to match search order.
-        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
-
-        ordered_result_embeddings = [vectors[x] for x in ids]
-
-        # Get the new order of results.
-        new_ordering = maximal_marginal_relevance(
-            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
-        )
-
-        # Reorder the values and return.
-        ret = []
-        for x in new_ordering:
-            # Function can return -1 index
-            if x == -1:
-                break
-            else:
-                ret.append(documents[x])
-        return ret
-
-    @classmethod
-    def from_texts(
-        cls,
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        collection_name: str = "LangChainCollection",
-        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
-        consistency_level: str = "Session",
-        index_params: Optional[dict] = None,
-        search_params: Optional[dict] = None,
-        drop_old: bool = False,
-        batch_size: int = 100,
-        ids: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-    ) -> Milvus:
-        """Create a Milvus collection, indexes it with HNSW, and insert data.
-
-        Args:
-            texts (List[str]): Text data.
-            embedding (Embeddings): Embedding function.
-            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
-                Defaults to None.
-            collection_name (str, optional): Collection name to use. Defaults to
-                "LangChainCollection".
-            connection_args (dict[str, Any], optional): Connection args to use. Defaults
-                to DEFAULT_MILVUS_CONNECTION.
-            consistency_level (str, optional): Which consistency level to use. Defaults
-                to "Session".
-            index_params (Optional[dict], optional): Which index_params to use. Defaults
-                to None.
-            search_params (Optional[dict], optional): Which search params to use.
-                Defaults to None.
-            drop_old (Optional[bool], optional): Whether to drop the collection with
-                that name if it exists. Defaults to False.
-            batch_size:
-                How many vectors upload per-request.
-                Default: 100
-            ids: Optional[Sequence[str]] = None,
-
-        Returns:
-            Milvus: Milvus Vector Store
-        """
-        vector_db = cls(
-            embedding_function=embedding,
-            collection_name=collection_name,
-            connection_args=connection_args,
-            consistency_level=consistency_level,
-            index_params=index_params,
-            search_params=search_params,
-            drop_old=drop_old,
-            **kwargs,
-        )
-        vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
-        return vector_db

+ 0 - 1759
api/core/vector_store/vector/qdrant.py

@@ -1,1759 +0,0 @@
-"""Wrapper around Qdrant vector database."""
-from __future__ import annotations
-
-import asyncio
-import functools
-import uuid
-import warnings
-from collections.abc import Callable, Generator, Iterable, Sequence
-from itertools import islice
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Optional, Union
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.vectorstores import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-from qdrant_client.http.models import PayloadSchemaType, TextIndexParams, TextIndexType, TokenizerType
-
-if TYPE_CHECKING:
-    from qdrant_client import grpc  # noqa
-    from qdrant_client.conversions import common_types
-    from qdrant_client.http import models as rest
-
-    DictFilter = dict[str, Union[str, int, bool, dict, list]]
-    MetadataFilter = Union[DictFilter, common_types.Filter]
-
-
-class QdrantException(Exception):
-    """Base class for all the Qdrant related exceptions"""
-
-
-def sync_call_fallback(method: Callable) -> Callable:
-    """
-    Decorator to call the synchronous method of the class if the async method is not
-    implemented. This decorator might be only used for the methods that are defined
-    as async in the class.
-    """
-
-    @functools.wraps(method)
-    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
-        try:
-            return await method(self, *args, **kwargs)
-        except NotImplementedError:
-            # If the async method is not implemented, call the synchronous method
-            # by removing the first letter from the method name. For example,
-            # if the async method is called ``aaad_texts``, the synchronous method
-            # will be called ``aad_texts``.
-            sync_method = functools.partial(
-                getattr(self, method.__name__[1:]), *args, **kwargs
-            )
-            return await asyncio.get_event_loop().run_in_executor(None, sync_method)
-
-    return wrapper
-
-
-class Qdrant(VectorStore):
-    """Wrapper around Qdrant vector database.
-
-    To use you should have the ``qdrant-client`` package installed.
-
-    Example:
-        .. code-block:: python
-
-            from qdrant_client import QdrantClient
-            from langchain import Qdrant
-
-            client = QdrantClient()
-            collection_name = "MyCollection"
-            qdrant = Qdrant(client, collection_name, embedding_function)
-    """
-
-    CONTENT_KEY = "page_content"
-    METADATA_KEY = "metadata"
-    GROUP_KEY = "group_id"
-    VECTOR_NAME = None
-
-    def __init__(
-        self,
-        client: Any,
-        collection_name: str,
-        embeddings: Optional[Embeddings] = None,
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        distance_strategy: str = "COSINE",
-        vector_name: Optional[str] = VECTOR_NAME,
-        embedding_function: Optional[Callable] = None,  # deprecated
-        is_new_collection: bool = False
-    ):
-        """Initialize with necessary components."""
-        try:
-            import qdrant_client
-        except ImportError:
-            raise ValueError(
-                "Could not import qdrant-client python package. "
-                "Please install it with `pip install qdrant-client`."
-            )
-
-        if not isinstance(client, qdrant_client.QdrantClient):
-            raise ValueError(
-                f"client should be an instance of qdrant_client.QdrantClient, "
-                f"got {type(client)}"
-            )
-
-        if embeddings is None and embedding_function is None:
-            raise ValueError(
-                "`embeddings` value can't be None. Pass `Embeddings` instance."
-            )
-
-        if embeddings is not None and embedding_function is not None:
-            raise ValueError(
-                "Both `embeddings` and `embedding_function` are passed. "
-                "Use `embeddings` only."
-            )
-
-        self._embeddings = embeddings
-        self._embeddings_function = embedding_function
-        self.client: qdrant_client.QdrantClient = client
-        self.collection_name = collection_name
-        self.content_payload_key = content_payload_key or self.CONTENT_KEY
-        self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
-        self.group_payload_key = group_payload_key or self.GROUP_KEY
-        self.vector_name = vector_name or self.VECTOR_NAME
-        self.group_id = group_id
-        self.is_new_collection= is_new_collection
-
-        if embedding_function is not None:
-            warnings.warn(
-                "Using `embedding_function` is deprecated. "
-                "Pass `Embeddings` instance to `embeddings` instead."
-            )
-
-        if not isinstance(embeddings, Embeddings):
-            warnings.warn(
-                "`embeddings` should be an instance of `Embeddings`."
-                "Using `embeddings` as `embedding_function` which is deprecated"
-            )
-            self._embeddings_function = embeddings
-            self._embeddings = None
-
-        self.distance_strategy = distance_strategy.upper()
-
-    @property
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embeddings
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            batch_size:
-                How many vectors upload per-request.
-                Default: 64
-            group_id:
-                collection group
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
-            texts, metadatas, ids, batch_size
-        ):
-            self.client.upsert(
-                collection_name=self.collection_name, points=points
-            )
-            added_ids.extend(batch_ids)
-        # if is new collection, create payload index on group_id
-        if self.is_new_collection:
-            # create payload index
-            self.client.create_payload_index(self.collection_name, self.group_payload_key,
-                                             field_schema=PayloadSchemaType.KEYWORD,
-                                             field_type=PayloadSchemaType.KEYWORD)
-            # creat full text index
-            text_index_params = TextIndexParams(
-                type=TextIndexType.TEXT,
-                tokenizer=TokenizerType.MULTILINGUAL,
-                min_token_len=2,
-                max_token_len=20,
-                lowercase=True
-            )
-            self.client.create_payload_index(self.collection_name, self.content_payload_key,
-                                             field_schema=text_index_params)
-        return added_ids
-
-    @sync_call_fallback
-    async def aadd_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            batch_size:
-                How many vectors upload per-request.
-                Default: 64
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import RestToGrpc
-
-        added_ids = []
-        for batch_ids, points in self._generate_rest_batches(
-            texts, metadatas, ids, batch_size
-        ):
-            await self.client.async_grpc_points.Upsert(
-                grpc.UpsertPoints(
-                    collection_name=self.collection_name,
-                    points=[RestToGrpc.convert_point_struct(point) for point in points],
-                )
-            )
-            added_ids.extend(batch_ids)
-
-        return added_ids
-
-    def similarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = self.similarity_search_with_score(
-            query,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def asimilarity_search(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to query.
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
-        return list(map(itemgetter(0), results))
-
-    def similarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        return self.similarity_search_with_score_by_vector(
-            self._embed_query(query),
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-
-    @sync_call_fallback
-    async def asimilarity_search_with_score(
-        self,
-        query: str,
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        return await self.asimilarity_search_with_score_by_vector(
-            self._embed_query(query),
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-
-    def similarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = self.similarity_search_with_score_by_vector(
-            embedding,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def asimilarity_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = await self.asimilarity_search_with_score_by_vector(
-            embedding,
-            k,
-            filter=filter,
-            search_params=search_params,
-            offset=offset,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return list(map(itemgetter(0), results))
-
-    def similarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        if filter is not None and isinstance(filter, dict):
-            warnings.warn(
-                "Using dict as a `filter` is deprecated. Please use qdrant-client "
-                "filters directly: "
-                "https://qdrant.tech/documentation/concepts/filtering/",
-                DeprecationWarning,
-            )
-            qdrant_filter = self._qdrant_filter_from_dict(filter)
-        else:
-            qdrant_filter = filter
-
-        query_vector = embedding
-        if self.vector_name is not None:
-            query_vector = (self.vector_name, embedding)  # type: ignore[assignment]
-
-        results = self.client.search(
-            collection_name=self.collection_name,
-            query_vector=query_vector,
-            query_filter=qdrant_filter,
-            search_params=search_params,
-            limit=k,
-            offset=offset,
-            with_payload=True,
-            with_vectors=True,
-            score_threshold=score_threshold,
-            consistency=consistency,
-            **kwargs,
-        )
-        return [
-            (
-                self._document_from_scored_point(
-                    result, self.content_payload_key, self.metadata_payload_key
-                ),
-                result.score,
-            )
-            for result in results
-        ]
-
-    def similarity_search_by_bm25(
-        self,
-        filter: Optional[MetadataFilter] = None,
-        k: int = 4
-    ) -> list[Document]:
-        """Return docs most similar by bm25.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        response = self.client.scroll(
-            collection_name=self.collection_name,
-            scroll_filter=filter,
-            limit=k,
-            with_payload=True,
-            with_vectors=True
-
-        )
-        results = response[0]
-        documents = []
-        for result in results:
-            if result:
-                documents.append(self._document_from_scored_point(
-                        result, self.content_payload_key, self.metadata_payload_key
-                    ))
-
-        return documents
-
-    @sync_call_fallback
-    async def asimilarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[MetadataFilter] = None,
-        search_params: Optional[common_types.SearchParams] = None,
-        offset: int = 0,
-        score_threshold: Optional[float] = None,
-        consistency: Optional[common_types.ReadConsistency] = None,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: Filter by metadata. Defaults to None.
-            search_params: Additional search params
-            offset:
-                Offset of the first result to return.
-                May be used to paginate results.
-                Note: large offset values may cause performance issues.
-            score_threshold:
-                Define a minimal score threshold for the result.
-                If defined, less similar results will not be returned.
-                Score of the returned result might be higher or smaller than the
-                threshold depending on the Distance function used.
-                E.g. for cosine similarity only higher scores will be returned.
-            consistency:
-                Read consistency of the search. Defines how many replicas should be
-                queried before returning the result.
-                Values:
-                - int - number of replicas to query, values should present in all
-                        queried replicas
-                - 'majority' - query all replicas, but return values present in the
-                               majority of replicas
-                - 'quorum' - query the majority of replicas, return values present in
-                             all of them
-                - 'all' - query all replicas, and return values present in all replicas
-
-        Returns:
-            List of documents most similar to the query text and distance for each.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import RestToGrpc
-        from qdrant_client.http import models as rest
-
-        if filter is not None and isinstance(filter, dict):
-            warnings.warn(
-                "Using dict as a `filter` is deprecated. Please use qdrant-client "
-                "filters directly: "
-                "https://qdrant.tech/documentation/concepts/filtering/",
-                DeprecationWarning,
-            )
-            qdrant_filter = self._qdrant_filter_from_dict(filter)
-        else:
-            qdrant_filter = filter
-
-        if qdrant_filter is not None and isinstance(qdrant_filter, rest.Filter):
-            qdrant_filter = RestToGrpc.convert_filter(qdrant_filter)
-
-        response = await self.client.async_grpc_points.Search(
-            grpc.SearchPoints(
-                collection_name=self.collection_name,
-                vector_name=self.vector_name,
-                vector=embedding,
-                filter=qdrant_filter,
-                params=search_params,
-                limit=k,
-                offset=offset,
-                with_payload=grpc.WithPayloadSelector(enable=True),
-                with_vectors=grpc.WithVectorsSelector(enable=False),
-                score_threshold=score_threshold,
-                read_consistency=consistency,
-                **kwargs,
-            )
-        )
-
-        return [
-            (
-                self._document_from_scored_point_grpc(
-                    result, self.content_payload_key, self.metadata_payload_key
-                ),
-                result.score,
-            )
-            for result in response.result
-        ]
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        query_embedding = self._embed_query(query)
-        return self.max_marginal_relevance_search_by_vector(
-            query_embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        query_embedding = self._embed_query(query)
-        return await self.amax_marginal_relevance_search_by_vector(
-            query_embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        results = self.max_marginal_relevance_search_with_score_by_vector(
-            embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
-        )
-        return list(map(itemgetter(0), results))
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        results = await self.amax_marginal_relevance_search_with_score_by_vector(
-            embedding, k, fetch_k, lambda_mult, **kwargs
-        )
-        return list(map(itemgetter(0), results))
-
-    def max_marginal_relevance_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        query_vector = embedding
-        if self.vector_name is not None:
-            query_vector = (self.vector_name, query_vector)  # type: ignore[assignment]
-
-        results = self.client.search(
-            collection_name=self.collection_name,
-            query_vector=query_vector,
-            with_payload=True,
-            with_vectors=True,
-            limit=fetch_k,
-        )
-        embeddings = [
-            result.vector.get(self.vector_name)  # type: ignore[index, union-attr]
-            if self.vector_name is not None
-            else result.vector
-            for result in results
-        ]
-        mmr_selected = maximal_marginal_relevance(
-            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
-        )
-        return [
-            (
-                self._document_from_scored_point(
-                    results[i], self.content_payload_key, self.metadata_payload_key
-                ),
-                results[i].score,
-            )
-            for i in mmr_selected
-        ]
-
-    @sync_call_fallback
-    async def amax_marginal_relevance_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs selected using the maximal marginal relevance.
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-                     Defaults to 20.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance and distance for
-            each.
-        """
-        from qdrant_client import grpc  # noqa
-        from qdrant_client.conversions.conversion import GrpcToRest
-
-        response = await self.client.async_grpc_points.Search(
-            grpc.SearchPoints(
-                collection_name=self.collection_name,
-                vector_name=self.vector_name,
-                vector=embedding,
-                with_payload=grpc.WithPayloadSelector(enable=True),
-                with_vectors=grpc.WithVectorsSelector(enable=True),
-                limit=fetch_k,
-            )
-        )
-        results = [
-            GrpcToRest.convert_vectors(result.vectors) for result in response.result
-        ]
-        embeddings: list[list[float]] = [
-            result.get(self.vector_name)  # type: ignore
-            if isinstance(result, dict)
-            else result
-            for result in results
-        ]
-        mmr_selected: list[int] = maximal_marginal_relevance(
-            np.array(embedding),
-            embeddings,
-            k=k,
-            lambda_mult=lambda_mult,
-        )
-        return [
-            (
-                self._document_from_scored_point_grpc(
-                    response.result[i],
-                    self.content_payload_key,
-                    self.metadata_payload_key,
-                ),
-                response.result[i].score,
-            )
-            for i in mmr_selected
-        ]
-
-    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
-        """Delete by vector ID or other criteria.
-
-        Args:
-            ids: List of ids to delete.
-            **kwargs: Other keyword arguments that subclasses might use.
-
-        Returns:
-            Optional[bool]: True if deletion is successful,
-            False otherwise, None if not implemented.
-        """
-        from qdrant_client.http import models as rest
-
-        result = self.client.delete(
-            collection_name=self.collection_name,
-            points_selector=ids,
-        )
-        return result.status == rest.UpdateStatus.COMPLETED
-
-    @classmethod
-    def from_texts(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        vector_name: Optional[str] = VECTOR_NAME,
-        batch_size: int = 64,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        """Construct Qdrant wrapper from a list of texts.
-
-        Args:
-            texts: A list of texts to be indexed in Qdrant.
-            embedding: A subclass of `Embeddings`, responsible for text vectorization.
-            metadatas:
-                An optional list of metadata. If provided it has to be of the same
-                length as a list of texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            location:
-                If `:memory:` - use in-memory Qdrant instance.
-                If `str` - use it as a `url` parameter.
-                If `None` - fallback to relying on `host` and `port` parameters.
-            url: either host or str of "Optional[scheme], host, Optional[port],
-                Optional[prefix]". Default: `None`
-            port: Port of the REST API interface. Default: 6333
-            grpc_port: Port of the gRPC interface. Default: 6334
-            prefer_grpc:
-                If true - use gPRC interface whenever possible in custom methods.
-                Default: False
-            https: If true - use HTTPS(SSL) protocol. Default: None
-            api_key: API key for authentication in Qdrant Cloud. Default: None
-            prefix:
-                If not None - add prefix to the REST URL path.
-                Example: service/v1 will result in
-                    http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
-                Default: None
-            timeout:
-                Timeout for REST and gRPC API requests.
-                Default: 5.0 seconds for REST and unlimited for gRPC
-            host:
-                Host name of Qdrant service. If url and host are None, set to
-                'localhost'. Default: None
-            path:
-                Path in which the vectors will be stored while using local mode.
-                Default: None
-            collection_name:
-                Name of the Qdrant collection to be used. If not provided,
-                it will be created randomly. Default: None
-            distance_func:
-                Distance function. One of: "Cosine" / "Euclid" / "Dot".
-                Default: "Cosine"
-            content_payload_key:
-                A payload key used to store the content of the document.
-                Default: "page_content"
-            metadata_payload_key:
-                A payload key used to store the metadata of the document.
-                Default: "metadata"
-            group_payload_key:
-                A payload key used to store the content of the document.
-                Default: "group_id"
-            group_id:
-                collection group id
-            vector_name:
-                Name of the vector to be used internally in Qdrant.
-                Default: None
-            batch_size:
-                How many vectors upload per-request.
-                Default: 64
-            shard_number: Number of shards in collection. Default is 1, minimum is 1.
-            replication_factor:
-                Replication factor for collection. Default is 1, minimum is 1.
-                Defines how many copies of each shard will be created.
-                Have effect only in distributed mode.
-            write_consistency_factor:
-                Write consistency factor for collection. Default is 1, minimum is 1.
-                Defines how many replicas should apply the operation for us to consider
-                it successful. Increasing this number will make the collection more
-                resilient to inconsistencies, but will also make it fail if not enough
-                replicas are available.
-                Does not have any performance impact.
-                Have effect only in distributed mode.
-            on_disk_payload:
-                If true - point`s payload will not be stored in memory.
-                It will be read from the disk every time it is requested.
-                This setting saves RAM by (slightly) increasing the response time.
-                Note: those payload values that are involved in filtering and are
-                indexed - remain in RAM.
-            hnsw_config: Params for HNSW index
-            optimizers_config: Params for optimizer
-            wal_config: Params for Write-Ahead-Log
-            quantization_config:
-                Params for quantization, if None - quantization will be disabled
-            init_from:
-                Use data stored in another collection to initialize this collection
-            force_recreate:
-                Force recreating the collection
-            **kwargs:
-                Additional arguments passed directly into REST client initialization
-
-        This is a user-friendly interface that:
-        1. Creates embeddings, one for each text
-        2. Initializes the Qdrant database as an in-memory docstore by default
-           (and overridable to a remote docstore)
-        3. Adds the text embeddings to the Qdrant database
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain import Qdrant
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
-        """
-        qdrant = cls._construct_instance(
-            texts,
-            embedding,
-            metadatas,
-            ids,
-            location,
-            url,
-            port,
-            grpc_port,
-            prefer_grpc,
-            https,
-            api_key,
-            prefix,
-            timeout,
-            host,
-            path,
-            collection_name,
-            distance_func,
-            content_payload_key,
-            metadata_payload_key,
-            group_payload_key,
-            group_id,
-            vector_name,
-            shard_number,
-            replication_factor,
-            write_consistency_factor,
-            on_disk_payload,
-            hnsw_config,
-            optimizers_config,
-            wal_config,
-            quantization_config,
-            init_from,
-            force_recreate,
-            **kwargs,
-        )
-        qdrant.add_texts(texts, metadatas, ids, batch_size)
-        return qdrant
-
-    @classmethod
-    @sync_call_fallback
-    async def afrom_texts(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        vector_name: Optional[str] = VECTOR_NAME,
-        batch_size: int = 64,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        """Construct Qdrant wrapper from a list of texts.
-
-        Args:
-            texts: A list of texts to be indexed in Qdrant.
-            embedding: A subclass of `Embeddings`, responsible for text vectorization.
-            metadatas:
-                An optional list of metadata. If provided it has to be of the same
-                length as a list of texts.
-            ids:
-                Optional list of ids to associate with the texts. Ids have to be
-                uuid-like strings.
-            location:
-                If `:memory:` - use in-memory Qdrant instance.
-                If `str` - use it as a `url` parameter.
-                If `None` - fallback to relying on `host` and `port` parameters.
-            url: either host or str of "Optional[scheme], host, Optional[port],
-                Optional[prefix]". Default: `None`
-            port: Port of the REST API interface. Default: 6333
-            grpc_port: Port of the gRPC interface. Default: 6334
-            prefer_grpc:
-                If true - use gPRC interface whenever possible in custom methods.
-                Default: False
-            https: If true - use HTTPS(SSL) protocol. Default: None
-            api_key: API key for authentication in Qdrant Cloud. Default: None
-            prefix:
-                If not None - add prefix to the REST URL path.
-                Example: service/v1 will result in
-                    http://localhost:6333/service/v1/{qdrant-endpoint} for REST API.
-                Default: None
-            timeout:
-                Timeout for REST and gRPC API requests.
-                Default: 5.0 seconds for REST and unlimited for gRPC
-            host:
-                Host name of Qdrant service. If url and host are None, set to
-                'localhost'. Default: None
-            path:
-                Path in which the vectors will be stored while using local mode.
-                Default: None
-            collection_name:
-                Name of the Qdrant collection to be used. If not provided,
-                it will be created randomly. Default: None
-            distance_func:
-                Distance function. One of: "Cosine" / "Euclid" / "Dot".
-                Default: "Cosine"
-            content_payload_key:
-                A payload key used to store the content of the document.
-                Default: "page_content"
-            metadata_payload_key:
-                A payload key used to store the metadata of the document.
-                Default: "metadata"
-            vector_name:
-                Name of the vector to be used internally in Qdrant.
-                Default: None
-            batch_size:
-                How many vectors upload per-request.
-                Default: 64
-            shard_number: Number of shards in collection. Default is 1, minimum is 1.
-            replication_factor:
-                Replication factor for collection. Default is 1, minimum is 1.
-                Defines how many copies of each shard will be created.
-                Have effect only in distributed mode.
-            write_consistency_factor:
-                Write consistency factor for collection. Default is 1, minimum is 1.
-                Defines how many replicas should apply the operation for us to consider
-                it successful. Increasing this number will make the collection more
-                resilient to inconsistencies, but will also make it fail if not enough
-                replicas are available.
-                Does not have any performance impact.
-                Have effect only in distributed mode.
-            on_disk_payload:
-                If true - point`s payload will not be stored in memory.
-                It will be read from the disk every time it is requested.
-                This setting saves RAM by (slightly) increasing the response time.
-                Note: those payload values that are involved in filtering and are
-                indexed - remain in RAM.
-            hnsw_config: Params for HNSW index
-            optimizers_config: Params for optimizer
-            wal_config: Params for Write-Ahead-Log
-            quantization_config:
-                Params for quantization, if None - quantization will be disabled
-            init_from:
-                Use data stored in another collection to initialize this collection
-            force_recreate:
-                Force recreating the collection
-            **kwargs:
-                Additional arguments passed directly into REST client initialization
-
-        This is a user-friendly interface that:
-        1. Creates embeddings, one for each text
-        2. Initializes the Qdrant database as an in-memory docstore by default
-           (and overridable to a remote docstore)
-        3. Adds the text embeddings to the Qdrant database
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain import Qdrant
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
-        """
-        qdrant = cls._construct_instance(
-            texts,
-            embedding,
-            metadatas,
-            ids,
-            location,
-            url,
-            port,
-            grpc_port,
-            prefer_grpc,
-            https,
-            api_key,
-            prefix,
-            timeout,
-            host,
-            path,
-            collection_name,
-            distance_func,
-            content_payload_key,
-            metadata_payload_key,
-            vector_name,
-            shard_number,
-            replication_factor,
-            write_consistency_factor,
-            on_disk_payload,
-            hnsw_config,
-            optimizers_config,
-            wal_config,
-            quantization_config,
-            init_from,
-            force_recreate,
-            **kwargs,
-        )
-        await qdrant.aadd_texts(texts, metadatas, ids, batch_size)
-        return qdrant
-
-    @classmethod
-    def _construct_instance(
-        cls: type[Qdrant],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        location: Optional[str] = None,
-        url: Optional[str] = None,
-        port: Optional[int] = 6333,
-        grpc_port: int = 6334,
-        prefer_grpc: bool = False,
-        https: Optional[bool] = None,
-        api_key: Optional[str] = None,
-        prefix: Optional[str] = None,
-        timeout: Optional[float] = None,
-        host: Optional[str] = None,
-        path: Optional[str] = None,
-        collection_name: Optional[str] = None,
-        distance_func: str = "Cosine",
-        content_payload_key: str = CONTENT_KEY,
-        metadata_payload_key: str = METADATA_KEY,
-        group_payload_key: str = GROUP_KEY,
-        group_id: str = None,
-        vector_name: Optional[str] = VECTOR_NAME,
-        shard_number: Optional[int] = None,
-        replication_factor: Optional[int] = None,
-        write_consistency_factor: Optional[int] = None,
-        on_disk_payload: Optional[bool] = None,
-        hnsw_config: Optional[common_types.HnswConfigDiff] = None,
-        optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
-        wal_config: Optional[common_types.WalConfigDiff] = None,
-        quantization_config: Optional[common_types.QuantizationConfig] = None,
-        init_from: Optional[common_types.InitFrom] = None,
-        force_recreate: bool = False,
-        **kwargs: Any,
-    ) -> Qdrant:
-        try:
-            import qdrant_client
-        except ImportError:
-            raise ValueError(
-                "Could not import qdrant-client python package. "
-                "Please install it with `pip install qdrant-client`."
-            )
-        from qdrant_client.http import models as rest
-
-        # Just do a single quick embedding to get vector size
-        partial_embeddings = embedding.embed_documents(texts[:1])
-        vector_size = len(partial_embeddings[0])
-        collection_name = collection_name or uuid.uuid4().hex
-        distance_func = distance_func.upper()
-        is_new_collection = False
-        client = qdrant_client.QdrantClient(
-            location=location,
-            url=url,
-            port=port,
-            grpc_port=grpc_port,
-            prefer_grpc=prefer_grpc,
-            https=https,
-            api_key=api_key,
-            prefix=prefix,
-            timeout=timeout,
-            host=host,
-            path=path,
-            **kwargs,
-        )
-        all_collection_name = []
-        collections_response = client.get_collections()
-        collection_list = collections_response.collections
-        for collection in collection_list:
-            all_collection_name.append(collection.name)
-        if collection_name not in all_collection_name:
-            vectors_config = rest.VectorParams(
-                size=vector_size,
-                distance=rest.Distance[distance_func],
-            )
-
-            # If vector name was provided, we're going to use the named vectors feature
-            # with just a single vector.
-            if vector_name is not None:
-                vectors_config = {  # type: ignore[assignment]
-                    vector_name: vectors_config,
-                }
-
-            client.recreate_collection(
-                collection_name=collection_name,
-                vectors_config=vectors_config,
-                shard_number=shard_number,
-                replication_factor=replication_factor,
-                write_consistency_factor=write_consistency_factor,
-                on_disk_payload=on_disk_payload,
-                hnsw_config=hnsw_config,
-                optimizers_config=optimizers_config,
-                wal_config=wal_config,
-                quantization_config=quantization_config,
-                init_from=init_from,
-                timeout=int(timeout),  # type: ignore[arg-type]
-            )
-            is_new_collection = True
-        if force_recreate:
-            raise ValueError
-
-        # Get the vector configuration of the existing collection and vector, if it
-        # was specified. If the old configuration does not match the current one,
-        # an exception is being thrown.
-        collection_info = client.get_collection(collection_name=collection_name)
-        current_vector_config = collection_info.config.params.vectors
-        if isinstance(current_vector_config, dict) and vector_name is not None:
-            if vector_name not in current_vector_config:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} does not "
-                    f"contain vector named {vector_name}. Did you mean one of the "
-                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            current_vector_config = current_vector_config.get(
-                vector_name
-            )  # type: ignore[assignment]
-        elif isinstance(current_vector_config, dict) and vector_name is None:
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} uses named vectors. "
-                f"If you want to reuse it, please set `vector_name` to any of the "
-                f"existing named vectors: "
-                f"{', '.join(current_vector_config.keys())}."  # noqa
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-        elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-        ):
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} doesn't use named "
-                f"vectors. If you want to reuse it, please set `vector_name` to "
-                f"`None`. If you want to recreate the collection, set "
-                f"`force_recreate` parameter to `True`."
-            )
-
-        # Check if the vector configuration has the same dimensionality.
-        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for vectors with "
-                f"{current_vector_config.size} "  # type: ignore[union-attr]
-                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-
-        current_distance_func = (
-            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-        )
-        if current_distance_func != distance_func:
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for "
-                f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                f"similarity. Please set `distance_func` parameter to "
-                f"`{distance_func}` if you want to reuse it. If you want to "
-                f"recreate the collection, set `force_recreate` parameter to "
-                f"`True`."
-            )
-        qdrant = cls(
-            client=client,
-            collection_name=collection_name,
-            embeddings=embedding,
-            content_payload_key=content_payload_key,
-            metadata_payload_key=metadata_payload_key,
-            distance_strategy=distance_func,
-            vector_name=vector_name,
-            group_id=group_id,
-            group_payload_key=group_payload_key,
-            is_new_collection=is_new_collection
-        )
-        return qdrant
-
-    def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        """
-        The 'correct' relevance function
-        may differ depending on a few things, including:
-        - the distance / similarity metric used by the VectorStore
-        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
-        - embedding dimensionality
-        - etc.
-        """
-
-        if self.distance_strategy == "COSINE":
-            return self._cosine_relevance_score_fn
-        elif self.distance_strategy == "DOT":
-            return self._max_inner_product_relevance_score_fn
-        elif self.distance_strategy == "EUCLID":
-            return self._euclidean_relevance_score_fn
-        else:
-            raise ValueError(
-                "Unknown distance strategy, must be cosine, "
-                "max_inner_product, or euclidean"
-            )
-
-    def _similarity_search_with_relevance_scores(
-        self,
-        query: str,
-        k: int = 4,
-        **kwargs: Any,
-    ) -> list[tuple[Document, float]]:
-        """Return docs and relevance scores in the range [0, 1].
-
-        0 is dissimilar, 1 is most similar.
-
-        Args:
-            query: input text
-            k: Number of Documents to return. Defaults to 4.
-            **kwargs: kwargs to be passed to similarity search. Should include:
-                score_threshold: Optional, a floating point value between 0 to 1 to
-                    filter the resulting set of retrieved docs
-
-        Returns:
-            List of Tuples of (doc, similarity_score)
-        """
-        return self.similarity_search_with_score(query, k, **kwargs)
-
-    @classmethod
-    def _build_payloads(
-        cls,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]],
-        content_payload_key: str,
-        metadata_payload_key: str,
-        group_id: str,
-        group_payload_key: str
-    ) -> list[dict]:
-        payloads = []
-        for i, text in enumerate(texts):
-            if text is None:
-                raise ValueError(
-                    "At least one of the texts is None. Please remove it before "
-                    "calling .from_texts or .add_texts on Qdrant instance."
-                )
-            metadata = metadatas[i] if metadatas is not None else None
-            payloads.append(
-                {
-                    content_payload_key: text,
-                    metadata_payload_key: metadata,
-                    group_payload_key: group_id
-                }
-            )
-
-        return payloads
-
-    @classmethod
-    def _document_from_scored_point(
-        cls,
-        scored_point: Any,
-        content_payload_key: str,
-        metadata_payload_key: str,
-    ) -> Document:
-        return Document(
-            page_content=scored_point.payload.get(content_payload_key),
-            metadata=scored_point.payload.get(metadata_payload_key) or {},
-        )
-
-    @classmethod
-    def _document_from_scored_point_grpc(
-        cls,
-        scored_point: Any,
-        content_payload_key: str,
-        metadata_payload_key: str,
-    ) -> Document:
-        from qdrant_client.conversions.conversion import grpc_to_payload
-
-        payload = grpc_to_payload(scored_point.payload)
-        return Document(
-            page_content=payload[content_payload_key],
-            metadata=payload.get(metadata_payload_key) or {},
-        )
-
-    def _build_condition(self, key: str, value: Any) -> list[rest.FieldCondition]:
-        from qdrant_client.http import models as rest
-
-        out = []
-
-        if isinstance(value, dict):
-            for _key, value in value.items():
-                out.extend(self._build_condition(f"{key}.{_key}", value))
-        elif isinstance(value, list):
-            for _value in value:
-                if isinstance(_value, dict):
-                    out.extend(self._build_condition(f"{key}[]", _value))
-                else:
-                    out.extend(self._build_condition(f"{key}", _value))
-        else:
-            out.append(
-                rest.FieldCondition(
-                    key=key,
-                    match=rest.MatchValue(value=value),
-                )
-            )
-
-        return out
-
-    def _qdrant_filter_from_dict(
-        self, filter: Optional[DictFilter]
-    ) -> Optional[rest.Filter]:
-        from qdrant_client.http import models as rest
-
-        if not filter:
-            return None
-
-        return rest.Filter(
-            must=[
-                condition
-                for key, value in filter.items()
-                for condition in self._build_condition(key, value)
-            ]
-        )
-
-    def _embed_query(self, query: str) -> list[float]:
-        """Embed query text.
-
-        Used to provide backward compatibility with `embedding_function` argument.
-
-        Args:
-            query: Query text.
-
-        Returns:
-            List of floats representing the query embedding.
-        """
-        if self.embeddings is not None:
-            embedding = self.embeddings.embed_query(query)
-        else:
-            if self._embeddings_function is not None:
-                embedding = self._embeddings_function(query)
-            else:
-                raise ValueError("Neither of embeddings or embedding_function is set")
-        return embedding.tolist() if hasattr(embedding, "tolist") else embedding
-
-    def _embed_texts(self, texts: Iterable[str]) -> list[list[float]]:
-        """Embed search texts.
-
-        Used to provide backward compatibility with `embedding_function` argument.
-
-        Args:
-            texts: Iterable of texts to embed.
-
-        Returns:
-            List of floats representing the texts embedding.
-        """
-        if self.embeddings is not None:
-            embeddings = self.embeddings.embed_documents(list(texts))
-            if hasattr(embeddings, "tolist"):
-                embeddings = embeddings.tolist()
-        elif self._embeddings_function is not None:
-            embeddings = []
-            for text in texts:
-                embedding = self._embeddings_function(text)
-                if hasattr(embeddings, "tolist"):
-                    embedding = embedding.tolist()
-                embeddings.append(embedding)
-        else:
-            raise ValueError("Neither of embeddings or embedding_function is set")
-
-        return embeddings
-
-    def _generate_rest_batches(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        ids: Optional[Sequence[str]] = None,
-        batch_size: int = 64,
-        group_id: Optional[str] = None,
-    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
-        from qdrant_client.http import models as rest
-
-        texts_iterator = iter(texts)
-        metadatas_iterator = iter(metadatas or [])
-        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
-        while batch_texts := list(islice(texts_iterator, batch_size)):
-            # Take the corresponding metadata and id for each text in a batch
-            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
-            batch_ids = list(islice(ids_iterator, batch_size))
-
-            # Generate the embeddings for all the texts in a batch
-            batch_embeddings = self._embed_texts(batch_texts)
-
-            points = [
-                rest.PointStruct(
-                    id=point_id,
-                    vector=vector
-                    if self.vector_name is None
-                    else {self.vector_name: vector},
-                    payload=payload,
-                )
-                for point_id, vector, payload in zip(
-                    batch_ids,
-                    batch_embeddings,
-                    self._build_payloads(
-                        batch_texts,
-                        batch_metadatas,
-                        self.content_payload_key,
-                        self.metadata_payload_key,
-                        self.group_id,
-                        self.group_payload_key
-                    ),
-                )
-            ]
-
-            yield batch_ids, points

+ 0 - 506
api/core/vector_store/vector/weaviate.py

@@ -1,506 +0,0 @@
-"""Wrapper around weaviate vector database."""
-from __future__ import annotations
-
-import datetime
-from collections.abc import Callable, Iterable
-from typing import Any, Optional
-from uuid import uuid4
-
-import numpy as np
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.utils import get_from_dict_or_env
-from langchain.vectorstores.base import VectorStore
-from langchain.vectorstores.utils import maximal_marginal_relevance
-
-
-def _default_schema(index_name: str) -> dict:
-    return {
-        "class": index_name,
-        "properties": [
-            {
-                "name": "text",
-                "dataType": ["text"],
-            }
-        ],
-    }
-
-
-def _create_weaviate_client(**kwargs: Any) -> Any:
-    client = kwargs.get("client")
-    if client is not None:
-        return client
-
-    weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
-
-    try:
-        # the weaviate api key param should not be mandatory
-        weaviate_api_key = get_from_dict_or_env(
-            kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
-        )
-    except ValueError:
-        weaviate_api_key = None
-
-    try:
-        import weaviate
-    except ImportError:
-        raise ValueError(
-            "Could not import weaviate python  package. "
-            "Please install it with `pip install weaviate-client`"
-        )
-
-    auth = (
-        weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
-        if weaviate_api_key is not None
-        else None
-    )
-    client = weaviate.Client(weaviate_url, auth_client_secret=auth)
-
-    return client
-
-
-def _default_score_normalizer(val: float) -> float:
-    return 1 - val
-
-
-def _json_serializable(value: Any) -> Any:
-    if isinstance(value, datetime.datetime):
-        return value.isoformat()
-    return value
-
-
-class Weaviate(VectorStore):
-    """Wrapper around Weaviate vector database.
-
-    To use, you should have the ``weaviate-client`` python package installed.
-
-    Example:
-        .. code-block:: python
-
-            import weaviate
-            from langchain.vectorstores import Weaviate
-            client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
-            weaviate = Weaviate(client, index_name, text_key)
-
-    """
-
-    def __init__(
-        self,
-        client: Any,
-        index_name: str,
-        text_key: str,
-        embedding: Optional[Embeddings] = None,
-        attributes: Optional[list[str]] = None,
-        relevance_score_fn: Optional[
-            Callable[[float], float]
-        ] = _default_score_normalizer,
-        by_text: bool = True,
-    ):
-        """Initialize with Weaviate client."""
-        try:
-            import weaviate
-        except ImportError:
-            raise ValueError(
-                "Could not import weaviate python package. "
-                "Please install it with `pip install weaviate-client`."
-            )
-        if not isinstance(client, weaviate.Client):
-            raise ValueError(
-                f"client should be an instance of weaviate.Client, got {type(client)}"
-            )
-        self._client = client
-        self._index_name = index_name
-        self._embedding = embedding
-        self._text_key = text_key
-        self._query_attrs = [self._text_key]
-        self.relevance_score_fn = relevance_score_fn
-        self._by_text = by_text
-        if attributes is not None:
-            self._query_attrs.extend(attributes)
-
-    @property
-    def embeddings(self) -> Optional[Embeddings]:
-        return self._embedding
-
-    def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        return (
-            self.relevance_score_fn
-            if self.relevance_score_fn
-            else _default_score_normalizer
-        )
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[list[dict]] = None,
-        **kwargs: Any,
-    ) -> list[str]:
-        """Upload texts with metadata (properties) to Weaviate."""
-        from weaviate.util import get_valid_uuid
-
-        ids = []
-        embeddings: Optional[list[list[float]]] = None
-        if self._embedding:
-            if not isinstance(texts, list):
-                texts = list(texts)
-            embeddings = self._embedding.embed_documents(texts)
-
-        with self._client.batch as batch:
-            for i, text in enumerate(texts):
-                data_properties = {self._text_key: text}
-                if metadatas is not None:
-                    for key, val in metadatas[i].items():
-                        data_properties[key] = _json_serializable(val)
-
-                # Allow for ids (consistent w/ other methods)
-                # # Or uuids (backwards compatble w/ existing arg)
-                # If the UUID of one of the objects already exists
-                # then the existing object will be replaced by the new object.
-                _id = get_valid_uuid(uuid4())
-                if "uuids" in kwargs:
-                    _id = kwargs["uuids"][i]
-                elif "ids" in kwargs:
-                    _id = kwargs["ids"][i]
-
-                batch.add_data_object(
-                    data_object=data_properties,
-                    class_name=self._index_name,
-                    uuid=_id,
-                    vector=embeddings[i] if embeddings else None,
-                )
-                ids.append(_id)
-        return ids
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        if self._by_text:
-            return self.similarity_search_by_text(query, k, **kwargs)
-        else:
-            if self._embedding is None:
-                raise ValueError(
-                    "_embedding cannot be None for similarity_search when "
-                    "_by_text=False"
-                )
-            embedding = self._embedding.embed_query(query)
-            return self.similarity_search_by_vector(embedding, k, **kwargs)
-
-    def similarity_search_by_text(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        result = query_obj.with_near_text(content).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
-    def similarity_search_by_bm25(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Return docs using BM25F.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        properties = ['text']
-        result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
-    def similarity_search_by_vector(
-        self, embedding: list[float], k: int = 4, **kwargs: Any
-    ) -> list[Document]:
-        """Look up similar documents by embedding vector in Weaviate."""
-        vector = {"vector": embedding}
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        if kwargs.get("additional"):
-            query_obj = query_obj.with_additional(kwargs.get("additional"))
-        result = query_obj.with_near_vector(vector).with_limit(k).do()
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-        docs = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            docs.append(Document(page_content=text, metadata=res))
-        return docs
-
-    def max_marginal_relevance_search(
-        self,
-        query: str,
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        if self._embedding is not None:
-            embedding = self._embedding.embed_query(query)
-        else:
-            raise ValueError(
-                "max_marginal_relevance_search requires a suitable Embeddings object"
-            )
-
-        return self.max_marginal_relevance_search_by_vector(
-            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
-        )
-
-    def max_marginal_relevance_search_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        fetch_k: int = 20,
-        lambda_mult: float = 0.5,
-        **kwargs: Any,
-    ) -> list[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                        of diversity among the results with 0 corresponding
-                        to maximum diversity and 1 to minimum diversity.
-                        Defaults to 0.5.
-
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        vector = {"vector": embedding}
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-        if kwargs.get("where_filter"):
-            query_obj = query_obj.with_where(kwargs.get("where_filter"))
-        results = (
-            query_obj.with_additional("vector")
-            .with_near_vector(vector)
-            .with_limit(fetch_k)
-            .do()
-        )
-
-        payload = results["data"]["Get"][self._index_name]
-        embeddings = [result["_additional"]["vector"] for result in payload]
-        mmr_selected = maximal_marginal_relevance(
-            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
-        )
-
-        docs = []
-        for idx in mmr_selected:
-            text = payload[idx].pop(self._text_key)
-            payload[idx].pop("_additional")
-            meta = payload[idx]
-            docs.append(Document(page_content=text, metadata=meta))
-        return docs
-
-    def similarity_search_with_score(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> list[tuple[Document, float]]:
-        """
-        Return list of documents most similar to the query
-        text and cosine distance in float for each.
-        Lower score represents more similarity.
-        """
-        if self._embedding is None:
-            raise ValueError(
-                "_embedding cannot be None for similarity_search_with_score"
-            )
-        content: dict[str, Any] = {"concepts": [query]}
-        if kwargs.get("search_distance"):
-            content["certainty"] = kwargs.get("search_distance")
-        query_obj = self._client.query.get(self._index_name, self._query_attrs)
-
-        embedded_query = self._embedding.embed_query(query)
-        if not self._by_text:
-            vector = {"vector": embedded_query}
-            result = (
-                query_obj.with_near_vector(vector)
-                .with_limit(k)
-                .with_additional(["vector", "distance"])
-                .do()
-            )
-        else:
-            result = (
-                query_obj.with_near_text(content)
-                .with_limit(k)
-                .with_additional(["vector", "distance"])
-                .do()
-            )
-
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-
-        docs_and_scores = []
-        for res in result["data"]["Get"][self._index_name]:
-            text = res.pop(self._text_key)
-            score = res["_additional"]["distance"]
-            docs_and_scores.append((Document(page_content=text, metadata=res), score))
-        return docs_and_scores
-
-    @classmethod
-    def from_texts(
-        cls: type[Weaviate],
-        texts: list[str],
-        embedding: Embeddings,
-        metadatas: Optional[list[dict]] = None,
-        **kwargs: Any,
-    ) -> Weaviate:
-        """Construct Weaviate wrapper from raw documents.
-
-        This is a user-friendly interface that:
-            1. Embeds documents.
-            2. Creates a new index for the embeddings in the Weaviate instance.
-            3. Adds the documents to the newly created Weaviate index.
-
-        This is intended to be a quick way to get started.
-
-        Example:
-            .. code-block:: python
-
-                from langchain.vectorstores.weaviate import Weaviate
-                from langchain.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                weaviate = Weaviate.from_texts(
-                    texts,
-                    embeddings,
-                    weaviate_url="http://localhost:8080"
-                )
-        """
-
-        client = _create_weaviate_client(**kwargs)
-
-        from weaviate.util import get_valid_uuid
-
-        index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
-        embeddings = embedding.embed_documents(texts) if embedding else None
-        text_key = "text"
-        schema = _default_schema(index_name)
-        attributes = list(metadatas[0].keys()) if metadatas else None
-
-        # check whether the index already exists
-        if not client.schema.contains(schema):
-            client.schema.create_class(schema)
-
-        with client.batch as batch:
-            for i, text in enumerate(texts):
-                data_properties = {
-                    text_key: text,
-                }
-                if metadatas is not None:
-                    for key in metadatas[i].keys():
-                        data_properties[key] = metadatas[i][key]
-
-                # If the UUID of one of the objects already exists
-                # then the existing objectwill be replaced by the new object.
-                if "uuids" in kwargs:
-                    _id = kwargs["uuids"][i]
-                else:
-                    _id = get_valid_uuid(uuid4())
-
-                # if an embedding strategy is not provided, we let
-                # weaviate create the embedding. Note that this will only
-                # work if weaviate has been installed with a vectorizer module
-                # like text2vec-contextionary for example
-                params = {
-                    "uuid": _id,
-                    "data_object": data_properties,
-                    "class_name": index_name,
-                }
-                if embeddings is not None:
-                    params["vector"] = embeddings[i]
-
-                batch.add_data_object(**params)
-
-            batch.flush()
-
-        relevance_score_fn = kwargs.get("relevance_score_fn")
-        by_text: bool = kwargs.get("by_text", False)
-
-        return cls(
-            client,
-            index_name,
-            text_key,
-            embedding=embedding,
-            attributes=attributes,
-            relevance_score_fn=relevance_score_fn,
-            by_text=by_text,
-        )
-
-    def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None:
-        """Delete by vector IDs.
-
-        Args:
-            ids: List of ids to delete.
-        """
-
-        if ids is None:
-            raise ValueError("No ids provided to delete.")
-
-        # TODO: Check if this can be done in bulk
-        for id in ids:
-            self._client.data_object.delete(uuid=id)

+ 0 - 38
api/core/vector_store/weaviate_vector_store.py

@@ -1,38 +0,0 @@
-from core.vector_store.vector.weaviate import Weaviate
-
-
-class WeaviateVectorStore(Weaviate):
-    def del_texts(self, where_filter: dict):
-        if not where_filter:
-            raise ValueError('where_filter must not be empty')
-
-        self._client.batch.delete_objects(
-            class_name=self._index_name,
-            where=where_filter,
-            output='minimal'
-        )
-
-    def del_text(self, uuid: str) -> None:
-        self._client.data_object.delete(
-            uuid,
-            class_name=self._index_name
-        )
-
-    def text_exists(self, uuid: str) -> bool:
-        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
-            "path": ["doc_id"],
-            "operator": "Equal",
-            "valueText": uuid,
-        }).with_limit(1).do()
-
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
-
-        entries = result["data"]["Get"][self._index_name]
-        if len(entries) == 0:
-            return False
-
-        return True
-
-    def delete(self):
-        self._client.schema.delete_class(self._index_name)

+ 1 - 1
api/events/event_handlers/clean_when_dataset_deleted.py

@@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task
 def handle(sender, **kwargs):
     dataset = sender
     clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
-                             dataset.index_struct, dataset.collection_binding_id)
+                             dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)

+ 2 - 1
api/events/event_handlers/clean_when_document_deleted.py

@@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task
 def handle(sender, **kwargs):
     document_id = sender
     dataset_id = kwargs.get('dataset_id')
-    clean_document_task.delay(document_id, dataset_id)
+    doc_form = kwargs.get('doc_form')
+    clean_document_task.delay(document_id, dataset_id, doc_form)

+ 8 - 0
api/models/dataset.py

@@ -94,6 +94,14 @@ class Dataset(db.Model):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
             .filter(Document.dataset_id == self.id).scalar()
 
+    @property
+    def doc_form(self):
+        document = db.session.query(Document).filter(
+            Document.dataset_id == self.id).first()
+        if document:
+            return document.doc_form
+        return None
+
     @property
     def retrieval_model_dict(self):
         default_retrieval_model = {

+ 4 - 13
api/schedule/clean_unused_datasets_task.py

@@ -6,7 +6,7 @@ from flask import current_app
 from werkzeug.exceptions import NotFound
 
 import app
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, Document
 
@@ -41,18 +41,9 @@ def clean_unused_datasets_task():
                 if not documents or len(documents) == 0:
                     try:
                         # remove index
-                        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-                        kw_index = IndexBuilder.get_index(dataset, 'economy')
-                        # delete from vector index
-                        if vector_index:
-                            if dataset.collection_binding_id:
-                                vector_index.delete_by_group_id(dataset.id)
-                            else:
-                                if dataset.collection_binding_id:
-                                    vector_index.delete_by_group_id(dataset.id)
-                                else:
-                                    vector_index.delete()
-                        kw_index.delete()
+                        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
+                        index_processor.clean(dataset, None)
+
                         # update document
                         update_params = {
                             Document.enabled: False

+ 21 - 14
api/services/dataset_service.py

@@ -11,10 +11,11 @@ from flask_login import current_user
 from sqlalchemy import func
 
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.index.index import IndexBuilder
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.models.document import Document as RAGDocument
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -402,7 +403,7 @@ class DocumentService:
     @staticmethod
     def delete_document(document):
         # trigger document_was_deleted signal
-        document_was_deleted.send(document.id, dataset_id=document.dataset_id)
+        document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form)
 
         db.session.delete(document)
         db.session.commit()
@@ -1060,7 +1061,7 @@ class SegmentService:
 
         # save vector index
         try:
-            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)
+            VectorService.create_segments_vector([args['keywords']], [segment_document], dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             segment_document.enabled = False
@@ -1087,6 +1088,7 @@ class SegmentService:
         ).scalar()
         pre_segment_data_list = []
         segment_data_list = []
+        keywords_list = []
         for segment_item in segments:
             content = segment_item['content']
             doc_id = str(uuid.uuid4())
@@ -1119,15 +1121,13 @@ class SegmentService:
                 segment_document.answer = segment_item['answer']
             db.session.add(segment_document)
             segment_data_list.append(segment_document)
-            pre_segment_data = {
-                'segment': segment_document,
-                'keywords': segment_item['keywords']
-            }
-            pre_segment_data_list.append(pre_segment_data)
+
+            pre_segment_data_list.append(segment_document)
+            keywords_list.append(segment_item['keywords'])
 
         try:
             # save vector index
-            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+            VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
         except Exception as e:
             logging.exception("create segment index failed")
             for segment_document in segment_data_list:
@@ -1157,11 +1157,18 @@ class SegmentService:
                 db.session.commit()
                 # update segment index task
                 if args['keywords']:
-                    kw_index = IndexBuilder.get_index(dataset, 'economy')
-                    # delete from keyword index
-                    kw_index.delete_by_ids([segment.index_node_id])
-                    # save keyword index
-                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
+                    keyword = Keyword(dataset)
+                    keyword.delete_by_ids([segment.index_node_id])
+                    document = RAGDocument(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
+                    )
+                    keyword.add_texts([document], keywords_list=[args['keywords']])
             else:
                 segment_hash = helper.generate_text_hash(content)
                 tokens = 0

+ 5 - 4
api/services/file_service.py

@@ -9,8 +9,8 @@ from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
-from core.data_loader.file_extractor import FileExtractor
 from core.file.upload_file_parser import UploadFileParser
+from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
@@ -32,7 +32,8 @@ class FileService:
     def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
         extension = file.filename.split('.')[-1]
         etl_type = current_app.config['ETL_TYPE']
-        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
+            else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
         elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
@@ -136,7 +137,7 @@ class FileService:
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
 
-        text = FileExtractor.load(upload_file, return_text=True)
+        text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
 
         return text
@@ -164,7 +165,7 @@ class FileService:
         return generator, upload_file.mime_type
 
     @staticmethod
-    def get_public_image_preview(file_id: str) -> str:
+    def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()

+ 13 - 62
api/services/hit_testing_service.py

@@ -1,21 +1,18 @@
 import logging
-import threading
 import time
 
 import numpy as np
-from flask import current_app
-from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
 from sklearn.manifold import TSNE
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rerank.rerank import RerankRunner
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
-from services.retrieval_service import RetrievalService
 
 default_retrieval_model = {
     'search_method': 'semantic_search',
@@ -28,6 +25,7 @@ default_retrieval_model = {
     'score_threshold_enabled': False
 }
 
+
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
@@ -57,61 +55,15 @@ class HitTestingService:
 
         embeddings = CacheEmbedding(embedding_model)
 
-        all_documents = []
-        threads = []
-
-        # retrieval_model source with semantic
-        if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
-            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'top_k': retrieval_model['top_k'],
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings
-            })
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
-            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': str(dataset.id),
-                'query': query,
-                'search_method': retrieval_model['search_method'],
-                'embeddings': embeddings,
-                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                'top_k': retrieval_model['top_k'],
-                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
-                'all_documents': all_documents
-            })
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
-
-        for thread in threads:
-            thread.join()
-
-        if retrieval_model['search_method'] == 'hybrid_search':
-            model_manager = ModelManager()
-            rerank_model_instance = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                provider=retrieval_model['reranking_model']['reranking_provider_name'],
-                model_type=ModelType.RERANK,
-                model=retrieval_model['reranking_model']['reranking_model_name']
-            )
-
-            rerank_runner = RerankRunner(rerank_model_instance)
-            all_documents = rerank_runner.run(
-                query=query,
-                documents=all_documents,
-                score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
-                top_n=retrieval_model['top_k'],
-                user=f"account-{account.id}"
-            )
+        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+                                                  dataset_id=dataset.id,
+                                                  query=query,
+                                                  top_k=retrieval_model['top_k'],
+                                                  score_threshold=retrieval_model['score_threshold']
+                                                  if retrieval_model['score_threshold_enabled'] else None,
+                                                  reranking_model=retrieval_model['reranking_model']
+                                                  if retrieval_model['reranking_enable'] else None
+                                                  )
 
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
@@ -203,4 +155,3 @@ class HitTestingService:
 
         if not query or len(query) > 250:
             raise ValueError('Query is required and cannot exceed 250 characters')
-

+ 0 - 119
api/services/retrieval_service.py

@@ -1,119 +0,0 @@
-from typing import Optional
-
-from flask import Flask, current_app
-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
-from extensions.ext_database import db
-from models.dataset import Dataset
-
-default_retrieval_model = {
-    'search_method': 'semantic_search',
-    'reranking_enable': False,
-    'reranking_model': {
-        'reranking_provider_name': '',
-        'reranking_model_name': ''
-    },
-    'top_k': 2,
-    'score_threshold_enabled': False
-}
-
-
-class RetrievalService:
-
-    @classmethod
-    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                         all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search(
-                query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': top_k,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
-                }
-            )
-
-            if documents:
-                if reranking_model and search_method == 'semantic_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
-
-    @classmethod
-    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                               all_documents: list, search_method: str, embeddings: Embeddings):
-        with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
-
-            documents = vector_index.search_by_full_text_index(
-                query,
-                search_type='similarity_score_threshold',
-                top_k=top_k
-            )
-            if documents:
-                if reranking_model and search_method == 'full_text_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)

+ 31 - 54
api/services/vector_service.py

@@ -1,44 +1,18 @@
-
 from typing import Optional
 
-from langchain.schema import Document
-
-from core.index.index import IndexBuilder
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from models.dataset import Dataset, DocumentSegment
 
 
 class VectorService:
 
     @classmethod
-    def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
-        document = Document(
-            page_content=segment.content,
-            metadata={
-                "doc_id": segment.index_node_id,
-                "doc_hash": segment.index_node_hash,
-                "document_id": segment.document_id,
-                "dataset_id": segment.dataset_id,
-            }
-        )
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts([document], duplicate_check=True)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            if keywords and len(keywords) > 0:
-                index.create_segment_keywords(segment.index_node_id, keywords)
-            else:
-                index.add_texts([document])
-
-    @classmethod
-    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
+    def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
+                               segments: list[DocumentSegment], dataset: Dataset):
         documents = []
-        for pre_segment_data in pre_segment_data_list:
-            segment = pre_segment_data['segment']
+        for segment in segments:
             document = Document(
                 page_content=segment.content,
                 metadata={
@@ -49,30 +23,26 @@ class VectorService:
                 }
             )
             documents.append(document)
-
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents, duplicate_check=True)
+        if dataset.indexing_technique == 'high_quality':
+            # save vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.add_texts(documents, duplicate_check=True)
 
         # save keyword index
-        keyword_index = IndexBuilder.get_index(dataset, 'economy')
-        if keyword_index:
-            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
+        keyword = Keyword(dataset)
+
+        if keywords_list and len(keywords_list) > 0:
+            keyword.add_texts(documents, keyword_list=keywords_list)
+        else:
+            keyword.add_texts(documents)
 
     @classmethod
     def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
         # update segment index task
-        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
-        kw_index = IndexBuilder.get_index(dataset, 'economy')
-        # delete from vector index
-        if vector_index:
-            vector_index.delete_by_ids([segment.index_node_id])
 
-        # delete from keyword index
-        kw_index.delete_by_ids([segment.index_node_id])
-
-        # add new index
+        # format new index
         document = Document(
             page_content=segment.content,
             metadata={
@@ -82,13 +52,20 @@ class VectorService:
                 "dataset_id": segment.dataset_id,
             }
         )
+        if dataset.indexing_technique == 'high_quality':
+            # update vector index
+            vector = Vector(
+                dataset=dataset
+            )
+            vector.delete_by_ids([segment.index_node_id])
+            vector.add_texts([document], duplicate_check=True)
 
-        # save vector index
-        if vector_index:
-            vector_index.add_texts([document], duplicate_check=True)
+        # update keyword index
+        keyword = Keyword(dataset)
+        keyword.delete_by_ids([segment.index_node_id])
 
         # save keyword index
         if keywords and len(keywords) > 0:
-            kw_index.create_segment_keywords(segment.index_node_id, keywords)
+            keyword.add_texts([document], keywords_list=[keywords])
         else:
-            kw_index.add_texts([document])
+            keyword.add_texts([document])

+ 5 - 11
api/tasks/add_document_to_index_task.py

@@ -4,10 +4,10 @@ import time
 
 import click
 from celery import shared_task
-from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.index import IndexBuilder
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Document as DatasetDocument
@@ -60,15 +60,9 @@ def add_document_to_index_task(dataset_document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        # save vector index
-        index = IndexBuilder.get_index(dataset, 'high_quality')
-        if index:
-            index.add_texts(documents)
-
-        # save keyword index
-        index = IndexBuilder.get_index(dataset, 'economy')
-        if index:
-            index.add_texts(documents)
+        index_type = dataset.doc_form
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        index_processor.load(dataset, documents)
 
         end_at = time.perf_counter()
         logging.info(

部分文件因文件數量過多而無法顯示