
feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
John Wang · 1 year ago
commit 3241e4015b
91 changed files with 2689 additions and 3139 deletions
  1. api/app.py (+1 -2)
  2. api/commands.py (+35 -0)
  3. api/config.py (+2 -0)
  4. api/controllers/console/datasets/data_source.py (+11 -10)
  5. api/controllers/console/datasets/file.py (+2 -28)
  6. api/controllers/console/version.py (+7 -2)
  7. api/core/__init__.py (+0 -20)
  8. api/core/agent/agent_builder.py (+8 -11)
  9. api/core/callback_handler/agent_loop_gather_callback_handler.py (+1 -29)
  10. api/core/callback_handler/dataset_tool_callback_handler.py (+1 -50)
  11. api/core/callback_handler/index_tool_callback_handler.py (+8 -21)
  12. api/core/callback_handler/llm_callback_handler.py (+31 -85)
  13. api/core/callback_handler/main_chain_gather_callback_handler.py (+12 -70)
  14. api/core/callback_handler/std_out_callback_handler.py (+36 -9)
  15. api/core/chain/chain_builder.py (+2 -4)
  16. api/core/chain/llm_router_chain.py (+6 -4)
  17. api/core/chain/main_chain_builder.py (+10 -8)
  18. api/core/chain/multi_dataset_router_chain.py (+62 -16)
  19. api/core/chain/sensitive_word_avoidance_chain.py (+7 -2)
  20. api/core/chain/tool_chain.py (+12 -3)
  21. api/core/completion.py (+42 -17)
  22. api/core/conversation_message_task.py (+4 -4)
  23. api/core/data_loader/file_extractor.py (+43 -0)
  24. api/core/data_loader/loader/csv.py (+67 -0)
  25. api/core/data_loader/loader/excel.py (+43 -0)
  26. api/core/data_loader/loader/html.py (+35 -0)
  27. api/core/data_loader/loader/markdown.py (+134 -0)
  28. api/core/data_loader/loader/notion.py (+236 -236)
  29. api/core/data_loader/loader/pdf.py (+55 -0)
  30. api/core/docstore/dataset_docstore.py (+35 -45)
  31. api/core/docstore/empty_docstore.py (+0 -51)
  32. api/core/embedding/cached_embedding.py (+72 -0)
  33. api/core/embedding/openai_embedding.py (+0 -214)
  34. api/core/index/base.py (+59 -0)
  35. api/core/index/index.py (+41 -0)
  36. api/core/index/index_builder.py (+0 -60)
  37. api/core/index/keyword_table/jieba_keyword_table.py (+0 -159)
  38. api/core/index/keyword_table_index.py (+0 -135)
  39. api/core/index/keyword_table_index/jieba_keyword_table_handler.py (+33 -0)
  40. api/core/index/keyword_table_index/keyword_table_index.py (+238 -0)
  41. api/core/index/keyword_table_index/stopwords.py (+0 -0)
  42. api/core/index/query/synthesizer.py (+0 -79)
  43. api/core/index/readers/html_parser.py (+0 -22)
  44. api/core/index/readers/markdown_parser.py (+0 -111)
  45. api/core/index/readers/pdf_parser.py (+0 -56)
  46. api/core/index/readers/xlsx_parser.py (+0 -33)
  47. api/core/index/vector_index.py (+0 -136)
  48. api/core/index/vector_index/base.py (+175 -0)
  49. api/core/index/vector_index/qdrant_vector_index.py (+116 -0)
  50. api/core/index/vector_index/vector_index.py (+69 -0)
  51. api/core/index/vector_index/weaviate_vector_index.py (+132 -0)
  52. api/core/indexing_runner.py (+262 -283)
  53. api/core/llm/llm_builder.py (+17 -14)
  54. api/core/llm/provider/azure_provider.py (+4 -1)
  55. api/core/llm/streamable_azure_chat_open_ai.py (+13 -50)
  56. api/core/llm/streamable_azure_open_ai.py (+13 -6)
  57. api/core/llm/streamable_chat_open_ai.py (+14 -48)
  58. api/core/llm/streamable_open_ai.py (+13 -5)
  59. api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py (+1 -1)
  60. api/core/prompt/prompts.py (+0 -19)
  61. api/core/spiltter/fixed_text_splitter.py (+0 -0)
  62. api/core/tool/dataset_index_tool.py (+87 -0)
  63. api/core/tool/dataset_tool_builder.py (+0 -73)
  64. api/core/tool/llama_index_tool.py (+0 -43)
  65. api/core/vector_store/base.py (+0 -34)
  66. api/core/vector_store/qdrant_vector_store.py (+69 -0)
  67. api/core/vector_store/qdrant_vector_store_client.py (+0 -147)
  68. api/core/vector_store/vector_store.py (+0 -62)
  69. api/core/vector_store/vector_store_index_query.py (+0 -66)
  70. api/core/vector_store/weaviate_vector_store.py (+38 -0)
  71. api/core/vector_store/weaviate_vector_store_client.py (+0 -270)
  72. api/extensions/ext_vector_store.py (+0 -7)
  73. api/libs/helper.py (+6 -0)
  74. api/models/account.py (+0 -2)
  75. api/models/dataset.py (+30 -2)
  76. api/requirements.txt (+5 -4)
  77. api/services/app_model_config_service.py (+0 -1)
  78. api/services/dataset_service.py (+0 -3)
  79. api/services/hit_testing_service.py (+44 -36)
  80. api/tasks/add_document_to_index_task.py (+33 -48)
  81. api/tasks/add_segment_to_index_task.py (+27 -32)
  82. api/tasks/clean_dataset_task.py (+13 -20)
  83. api/tasks/clean_document_task.py (+7 -6)
  84. api/tasks/clean_notion_document_task.py (+7 -6)
  85. api/tasks/deal_dataset_vector_index_task.py (+36 -36)
  86. api/tasks/document_indexing_sync_task.py (+23 -23)
  87. api/tasks/document_indexing_task.py (+6 -16)
  88. api/tasks/document_indexing_update_task.py (+11 -20)
  89. api/tasks/recover_document_indexing_task.py (+4 -9)
  90. api/tasks/remove_document_from_index_task.py (+5 -6)
  91. api/tasks/remove_segment_from_index_task.py (+18 -8)

api/app.py (+1 -2)

@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)

api/commands.py (+35 -0)

@@ -1,15 +1,19 @@
 import datetime
+import logging
 import random
 import string
 
 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound
 
+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64
@@ -159,8 +163,39 @@ def generate_upper_string():
     return result
 
 
+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)

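For context, the new maintenance routine above is a plain click command registered on the Flask CLI via register_commands(app), so it is invoked as `flask recreate-all-dataset-indexes`. A minimal sketch of that wiring, with a hypothetical command name and a placeholder body (not project code):

import click
from flask import Flask

app = Flask(__name__)


@click.command('noop-maintenance', help='Hypothetical example command.')
def noop_maintenance():
    # The real command pages through 'high_quality' datasets and calls
    # IndexBuilder.get_index(...).recreate_dataset(...); this stub only echoes.
    click.echo(click.style('nothing to do', fg='green'))


app.cli.add_command(noop_maintenance)
# Shell usage: `flask noop-maintenance`
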
api/config.py (+2 -0)

@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
 
 
 class CloudEditionConfig(Config):

api/controllers/console/datasets/data_source.py (+11 -10)

@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()
         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
        }, 200
 
     @setup_required

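The controller now delegates Notion reading to the project's new NotionLoader and joins the page_content of the returned documents. A rough usage sketch of the same call shape; the token and IDs below are placeholders, and the keyword arguments are exactly those shown in the diff:

from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token='placeholder-integration-token',
    notion_workspace_id='placeholder-workspace-id',
    notion_obj_id='placeholder-page-or-database-id',
    notion_page_type='page'  # or 'database'
)

text_docs = loader.load()  # list of langchain Documents
content = "\n".join(doc.page_content for doc in text_docs)
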
api/controllers/console/datasets/file.py (+2 -28)

@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()
 
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
-
+        text = FileExtractor.load(upload_file, return_text=True)
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
 

api/controllers/console/version.py (+7 -2)

@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }
 
         content = json.loads(response.content)
         return {

api/core/__init__.py (+0 -20)

@@ -3,19 +3,11 @@ from typing import Optional
 
 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel
 
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
 
 
 class HostedOpenAICredential(BaseModel):
@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-
-    default_tfidf.stop_words = STOPWORDS
 
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))

api/core/agent/agent_builder.py (+8 -11)

@@ -2,7 +2,7 @@ from typing import Optional
 
 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )
 
-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]
 
         prompt = cls.build_agent_prompt_template(
             tools=tools,
@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

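This file shows the migration pattern that recurs throughout the commit: instead of wrapping handlers in a CallbackManager and passing callback_manager=..., handlers are passed directly as a callbacks=[...] list on LLMs, tools and executors. A minimal sketch of that style against stock langchain of this era (~0.0.2xx); the handler class is illustrative and an OPENAI_API_KEY is assumed in the environment:

from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import OpenAI


class PrintPromptHandler(BaseCallbackHandler):
    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        # Fired once per LLM call with the rendered prompt strings.
        print("prompt:", prompts[0])


llm = OpenAI(temperature=0, callbacks=[PrintPromptHandler()])
llm("Say hi")  # the handler prints the prompt before the completion is returned
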
api/core/callback_handler/agent_loop_gather_callback_handler.py (+1 -29)

@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -64,10 +65,6 @@
             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
 
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -75,21 +72,6 @@
         self._agent_loops = []
         self._current_loop = None
 
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
     def on_tool_start(
         self,
         serialized: Dict[str, Any],
@@ -151,16 +133,6 @@
         self._agent_loops = []
         self._current_loop = None
 
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer

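Several handlers in this commit (this one, DatasetToolCallbackHandler, LLMCallbackHandler, MainChainGatherCallbackHandler) now set raise_error = True. On langchain's BaseCallbackHandler this flag defaults to False, meaning exceptions raised inside a callback are caught and only logged; setting it to True lets them propagate to the caller. A small sketch of the idea (illustrative handler, not project code):

from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler


class StrictHandler(BaseCallbackHandler):
    raise_error: bool = True  # surface callback bugs instead of hiding them

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        if not prompts:
            raise ValueError("expected at least one prompt")  # now propagates to the caller
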
api/core/callback_handler/dataset_tool_callback_handler.py (+1 -50)

@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask
@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 
 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -66,52 +66,3 @@
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass

api/core/callback_handler/index_tool_callback_handler.py (+8 -21)

@@ -1,39 +1,26 @@
-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document
 
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
 
 
-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id
 
-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']
 
             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False

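The rewritten handler consumes plain langchain Documents whose metadata carries the segment id under doc_id, instead of llama_index response nodes. A usage sketch (the ids are placeholders); it assumes the documents returned by a dataset tool already carry doc_id in their metadata, as the project's retrieval code is expected to provide:

from langchain.schema import Document

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler

docs = [
    Document(page_content="retrieved chunk", metadata={"doc_id": "placeholder-node-id-1"}),
    Document(page_content="another chunk", metadata={"doc_id": "placeholder-node-id-2"}),
]

handler = DatasetIndexToolCallbackHandler(dataset_id="placeholder-dataset-id")
handler.on_tool_end(docs)  # increments DocumentSegment.hit_count for each doc_id
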
api/core/callback_handler/llm_callback_handler.py (+31 -85)

@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True
 
     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):
@@ -25,41 +26,41 @@
         """Whether to call verbose callbacks even if verbose is False."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()
 
-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
 
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -95,58 +96,3 @@
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(
-            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_tool_start(
-            self,
-            serialized: Dict[str, Any],
-            input_str: str,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        pass
-
-    def on_agent_finish(
-            self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        pass

api/core/callback_handler/main_chain_gather_callback_handler.py (+12 -70)

@@ -1,10 +1,9 @@
 import logging
 import time
 
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 
 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -50,13 +50,15 @@
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -74,64 +76,4 @@
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
-        self.clear_chain_results()
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
+        self.clear_chain_results()

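Both this handler and DifyStdOutCallbackHandler below now derive the chain type from serialized['id'][-1]. In the upgraded langchain, on_chain_start receives a serialized representation whose 'id' field is the class path split into a list, so its last element is the class name (the pre-upgrade API exposed it as serialized['name']). Shape-only illustration:

serialized = {"id": ["langchain", "chains", "llm", "LLMChain"]}  # shape only, other keys omitted
chain_type = serialized["id"][-1]
assert chain_type == "LLMChain"
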
api/core/callback_handler/std_out_callback_handler.py (+36 -9)

@@ -1,9 +1,10 @@
+import os
 import sys
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
 
 
 class DifyStdOutCallbackHandler(BaseCallbackHandler):
@@ -13,17 +14,23 @@
         """Initialize callback handler."""
         self.color = color
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
@@ -44,8 +51,8 @@
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -117,6 +124,26 @@
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
 
+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
 
 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""

api/core/chain/chain_builder.py (+2 -4)

@@ -1,7 +1,5 @@
 from typing import Optional
 
-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )
 
     @classmethod
@@ -27,7 +25,7 @@
                 sensitive_words=sensitive_words.split(","),
                 canned_response=tool_config.get("canned_response", ''),
                 output_key="sensitive_word_avoidance_output",
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+                callbacks=[DifyStdOutCallbackHandler()],
                 **kwargs
             )
 

api/core/chain/llm_router_chain.py (+6 -4)

@@ -1,15 +1,16 @@
 """Base classes for LLM-powered router chains."""
 from __future__ import annotations
 
-import json
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
 
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator
 
 from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+from langchain.schema import BaseOutputParser, OutputParserException
 
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
@@ -51,8 +52,9 @@
             raise ValueError
 
     def _call(
-        self,
-        inputs: Dict[str, Any]
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],

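The _call override now matches the upgraded Chain interface, which passes an optional CallbackManagerForChainRun. A minimal custom chain with that signature (sketch against langchain ~0.0.2xx; the class name and io keys are illustrative, not part of the commit):

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class EchoChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # run_manager, when provided, carries the callbacks for child runs.
        return {self.output_key: inputs[self.input_key]}


# EchoChain()({"input": "hi"}) -> {"input": "hi", "output": "hi"}
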
api/core/chain/main_chain_builder.py (+10 -8)

@@ -1,11 +1,9 @@
-from typing import Optional, List
+from typing import Optional, List, cast
 
-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"
@@ -30,6 +29,7 @@
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )
@@ -42,9 +42,8 @@
             return None
 
         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)
 
         # build main chain
         overall_chain = SequentialChain(
@@ -57,7 +56,9 @@
         return overall_chain
 
     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
@@ -93,7 +94,8 @@
                     tenant_id=tenant_id,
                     datasets=datasets,
                     conversation_message_task=conversation_message_task,
-                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                    rest_tokens=rest_tokens,
+                    callbacks=[DifyStdOutCallbackHandler()]
                 )
                 chains.append(multi_dataset_router_chain)
 

+ 62 - 16
api/core/chain/multi_dataset_router_chain.py

@@ -1,9 +1,9 @@
+import math
 from typing import Mapping, List, Dict, Any, Optional
 from typing import Mapping, List, Dict, Any, Optional
 
 
-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.base import Chain
-from langchain.schema import BaseLanguageModel
 from pydantic import Extra
 from pydantic import Extra
 
 
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule
 
 
+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \
 the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
 
 
     router_chain: LLMRouterChain
     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""
     """Map of name to candidate chains that inputs can be routed to."""
 
 
     class Config:
     class Config:
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             tenant_id: str,
             datasets: List[Dataset],
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
             **kwargs: Any,
     ):
     ):
         """Convenience constructor for instantiating from destination prompts."""
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             model_name='gpt-3.5-turbo',
             temperature=0,
             temperature=0,
             max_tokens=1024,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )
         )
 
 
-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
             destinations=destinations_str
         )
         )
+
         router_prompt = PromptTemplate(
         router_prompt = PromptTemplate(
             template=router_template,
             template=router_template,
             input_variables=["input"],
             input_variables=["input"],
             output_parser=RouterOutputParser(),
             output_parser=RouterOutputParser(),
         )
         )
+
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         dataset_tools = {}
         for dataset in datasets:
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # skip datasets that have no available documents
+            if dataset.available_document_count == 0:
+                continue
+
+            # fulfill the description when it is empty
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )
             )
 
 
-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool
 
 
         return cls(
         return cls(
             router_chain=router_chain,
             router_chain=router_chain,
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
             **kwargs,
         )
         )
 
 
+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
     def _call(
         self,
         self,
-        inputs: Dict[str, Any]
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
             return {"text": ''}

+ 7 - 2
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
 
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.base import Chain
 
 
 
 
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
                 return self.canned_response
         return text
         return text
 
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
         return {self.output_key: output}
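
The new `_call(self, inputs, run_manager=None)` shape seen here is the `Chain` interface of the upgraded langchain, which passes a per-run callback manager instead of a chain-level `CallbackManager`. A minimal sketch of a custom chain against that interface (the class and its keys are hypothetical, for illustration only):

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class UppercaseChain(Chain):
    """Hypothetical chain showing the post-upgrade _call signature."""
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {self.output_key: inputs[self.input_key].upper()}

# UppercaseChain().run("hi") -> "HI"
```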

+ 12 - 3
api/core/chain/tool_chain.py

@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
 
+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
 from langchain.tools import BaseTool
 
 
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         """
         return [self.output_key]
         return [self.output_key]
 
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}
         return {self.output_key: output}
 
 
-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
         output = await self.tool.arun(input, self.verbose)
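
For reference, the sync and async paths wrapped by `ToolChain._call` / `_acall` correspond to `BaseTool.run` and `BaseTool.arun`. A small illustrative sketch with a hypothetical echo tool (not part of the codebase):

```python
import asyncio

from langchain.tools import Tool

# hypothetical tool, for illustration only
async def _aecho(text: str) -> str:
    return text

echo_tool = Tool(
    name="echo",
    description="returns its input unchanged",
    func=lambda text: text,
    coroutine=_aecho,
)

print(echo_tool.run("hello"))                # sync path (ToolChain._call)
print(asyncio.run(echo_tool.arun("hello")))  # async path (ToolChain._acall)
```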

+ 42 - 17
api/core/completion.py

@@ -1,17 +1,18 @@
 import logging
 import logging
 from typing import Optional, List, Union, Tuple
 from typing import Optional, List, Union, Tuple
 
 
-from langchain.callbacks import CallbackManager
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
 from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
 from requests.exceptions import ChunkedEncodingError
 
 
 from core.constant import llm_constant
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
 from core.chain.main_chain_builder import MainChainBuilder
@@ -34,8 +35,6 @@ class Completion:
         """
         """
         errors: ProviderTokenNotInitError
         errors: ProviderTokenNotInitError
         """
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         memory = None
         if conversation:
         if conversation:
             # get memory of conversation (read-only)
             # get memory of conversation (read-only)
@@ -48,6 +47,14 @@ class Completion:
 
 
             inputs = conversation.inputs
             inputs = conversation.inputs
 
 
+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             task_id=task_id,
             app=app,
             app=app,
@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
             conversation_message_task=conversation_message_task
         )
         )
@@ -115,7 +123,7 @@ class Completion:
             memory=memory
             memory=memory
         )
         )
 
 
-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
 
 
         cls.recale_llm_max_tokens(
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
             final_llm=final_llm,
@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
             return messages, ['\nHuman:']
             return messages, ['\nHuman:']
 
 
     @classmethod
     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool,
-                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+                          streaming: bool,
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]
 
 
     @classmethod
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory
         return memory
 
 
     @classmethod
     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             tenant_id=tenant_id,
             model=app_model_config.model_dict
             model=app_model_config.model_dict
@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens
         max_tokens = llm.max_tokens
 
 
-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens
 
 
     @classmethod
     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
             streaming=streaming
         )
         )
 
 
-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
 
 
         cls.recale_llm_max_tokens(
         cls.recale_llm_max_tokens(
             final_llm=llm,
             final_llm=llm,
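
The rest-token budget computed by `get_validate_rest_tokens` is simply the model's context length minus the completion reservation (`max_tokens`) and the prompt without memory or retrieved context. A worked sketch with hypothetical numbers:

```python
def calc_rest_tokens(model_limited_tokens: int, max_tokens: int, prompt_tokens: int) -> int:
    rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
    if rest_tokens < 0:
        raise ValueError("Query or prefix prompt is too long")
    return rest_tokens

# gpt-3.5-turbo-style budget: 4096-token context, 512 reserved for the completion,
# 600 tokens of prompt scaffolding -> 2984 tokens left for memory + retrieved context
assert calc_rest_tokens(4096, 512, 600) == 2984
```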

+ 4 - 4
api/core/conversation_message_task.py

@@ -293,12 +293,12 @@ class PubHandler:
         if not user:
         if not user:
             raise ValueError("user is required")
             raise ValueError("user is required")
 
 
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result:{}-{}".format(user_str, task_id)
         return "generate_result:{}-{}".format(user_str, task_id)
 
 
     @classmethod
     @classmethod
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)
 
 
     def pub_text(self, text: str):
     def pub_text(self, text: str):
@@ -306,10 +306,10 @@ class PubHandler:
             'event': 'message',
             'event': 'message',
             'data': {
             'data': {
                 'task_id': self._task_id,
                 'task_id': self._task_id,
-                'message_id': self._message.id,
+                'message_id': str(self._message.id),
                 'text': text,
                 'text': text,
                 'mode': self._conversation.mode,
                 'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id
+                'conversation_id': str(self._conversation.id)
             }
             }
         }
         }
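
The added `str(...)` wrappers matter because the ids involved are (assumed here to be) `uuid.UUID` columns, which can neither be concatenated with strings nor JSON-encoded directly. A tiny illustration:

```python
import json
import uuid

user_id = uuid.uuid4()

# 'account-' + user_id         -> TypeError: can only concatenate str (not "UUID") to str
# json.dumps({'id': user_id})  -> TypeError: Object of type UUID is not JSON serializable

channel = 'account-' + str(user_id)          # works
payload = json.dumps({'id': str(user_id)})   # works
```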
 
 

+ 43 - 0
api/core/data_loader/file_extractor.py

@@ -0,0 +1,43 @@
+import tempfile
+from pathlib import Path
+from typing import List, Union
+
+from langchain.document_loaders import TextLoader, Docx2txtLoader
+from langchain.schema import Document
+
+from core.data_loader.loader.csv import CSVLoader
+from core.data_loader.loader.excel import ExcelLoader
+from core.data_loader.loader.html import HTMLLoader
+from core.data_loader.loader.markdown import MarkdownLoader
+from core.data_loader.loader.pdf import PdfLoader
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+
+class FileExtractor:
+    @classmethod
+    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document], str]:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(upload_file.key).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            storage.download(upload_file.key, file_path)
+
+            input_file = Path(file_path)
+            delimiter = '\n'
+            if input_file.suffix == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif input_file.suffix == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif input_file.suffix in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif input_file.suffix in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif input_file.suffix == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif input_file.suffix == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)
+
+            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
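
A hypothetical usage sketch of the suffix-based dispatch above; `upload_file` is assumed to be a `models.model.UploadFile` row, and the call needs the Flask app context with storage configured so the file can be downloaded to a temp path:

```python
from core.data_loader.file_extractor import FileExtractor

# upload_file.key ends in .pdf, .docx, .csv, .xlsx, .md, .html or .txt
documents = FileExtractor.load(upload_file)                     # -> List[Document]
plain_text = FileExtractor.load(upload_file, return_text=True)  # -> str joined with '\n'
```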

+ 67 - 0
api/core/data_loader/loader/csv.py

@@ -0,0 +1,67 @@
+import csv
+import logging
+from typing import Optional, Dict, List
+
+from langchain.document_loaders import CSVLoader as LCCSVLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class CSVLoader(LCCSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            source_column: Optional[str] = None,
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
+    ):
+        self.file_path = file_path
+        self.source_column = source_column
+        self.encoding = encoding
+        self.csv_args = csv_args or {}
+        self.autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self._read_from_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
+                            docs = self._read_from_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def _read_from_file(self, csvfile):
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
+            try:
+                source = (
+                    row[self.source_column]
+                    if self.source_column is not None
+                    else ''
+                )
+            except KeyError:
+                raise ValueError(
+                    f"Source column '{self.source_column}' not found in CSV file."
+                )
+            metadata = {"source": source, "row": i}
+            doc = Document(page_content=content, metadata=metadata)
+            docs.append(doc)
+
+        return docs
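
Each CSV row becomes one Document whose `page_content` is a block of `header: value` lines. A sketch with a hypothetical file, assuming the repository and its langchain dependency are importable:

```python
import csv

rows = [{"name": "apple", "price": "3"}, {"name": "pear", "price": "5"}]
with open("/tmp/fruit.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["name", "price"])
    writer.writeheader()
    writer.writerows(rows)

from core.data_loader.loader.csv import CSVLoader

docs = CSVLoader("/tmp/fruit.csv", autodetect_encoding=True).load()
print(docs[0].page_content)   # "name: apple\nprice: 3"
print(docs[0].metadata)       # {'source': '', 'row': 0}
```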

+ 43 - 0
api/core/data_loader/loader/excel.py

@@ -0,0 +1,43 @@
+import json
+import logging
+from typing import List
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from openpyxl.reader.excel import load_workbook
+
+logger = logging.getLogger(__name__)
+
+
+class ExcelLoader(BaseLoader):
+    """Load xlxs files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        data = []
+        keys = []
+        wb = load_workbook(filename=self._file_path, read_only=True)
+        # loop over all sheets
+        for sheet in wb:
+            for row in sheet.iter_rows(values_only=True):
+                if all(v is None for v in row):
+                    continue
+                if keys == []:
+                    keys = list(map(str, row))
+                else:
+                    row_dict = dict(zip(keys, row))
+                    row_dict = {k: v for k, v in row_dict.items() if v}
+                    data.append(json.dumps(row_dict, ensure_ascii=False))
+
+        return [Document(page_content='\n\n'.join(data))]
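
The first non-empty row is treated as the header, every later row is serialized to JSON (empty cells dropped), and everything is joined into a single Document. A sketch with a hypothetical workbook:

```python
from openpyxl import Workbook

wb = Workbook()
ws = wb.active
ws.append(["name", "price"])
ws.append(["apple", 3])
ws.append(["pear", None])     # empty cells are dropped from the JSON
wb.save("/tmp/fruit.xlsx")

from core.data_loader.loader.excel import ExcelLoader

docs = ExcelLoader("/tmp/fruit.xlsx").load()
print(docs[0].page_content)
# {"name": "apple", "price": 3}
#
# {"name": "pear"}
```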

+ 35 - 0
api/core/data_loader/loader/html.py

@@ -0,0 +1,35 @@
+import logging
+from typing import List
+
+from bs4 import BeautifulSoup
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class HTMLLoader(BaseLoader):
+    """Load html files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        return [Document(page_content=self._load_as_text())]
+
+    def _load_as_text(self) -> str:
+        with open(self._file_path, "rb") as fp:
+            soup = BeautifulSoup(fp, 'html.parser')
+            text = soup.get_text()
+            text = text.strip() if text else ''
+
+        return text

+ 134 - 0
api/core/data_loader/loader/markdown.py

@@ -0,0 +1,134 @@
+import logging
+import re
+from typing import Optional, List, Tuple, cast
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MarkdownLoader(BaseLoader):
+    """Load md files.
+
+
+    Args:
+        file_path: Path to the file to load.
+
+        remove_hyperlinks: Whether to remove hyperlinks from the text.
+
+        remove_images: Whether to remove images from the text.
+
+        encoding: File encoding to use. If `None`, the file will be loaded
+        with the default system encoding.
+
+        autodetect_encoding: Whether to try to autodetect the file encoding
+            if the specified encoding fails.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        remove_hyperlinks: bool = True,
+        remove_images: bool = True,
+        encoding: Optional[str] = None,
+        autodetect_encoding: bool = True,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._remove_hyperlinks = remove_hyperlinks
+        self._remove_images = remove_images
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        tups = self.parse_tups(self._file_path)
+        documents = []
+        for header, value in tups:
+            value = value.strip()
+            if header is None:
+                documents.append(Document(page_content=value))
+            else:
+                documents.append(Document(page_content=f"\n\n{header}\n{value}"))
+
+        return documents
+
+    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
+        """Convert a markdown file to a dictionary.
+
+        The keys are the headers and the values are the text under each header.
+
+        """
+        markdown_tups: List[Tuple[Optional[str], str]] = []
+        lines = markdown_text.split("\n")
+
+        current_header = None
+        current_text = ""
+
+        for line in lines:
+            header_match = re.match(r"^#+\s", line)
+            if header_match:
+                if current_header is not None:
+                    markdown_tups.append((current_header, current_text))
+
+                current_header = line
+                current_text = ""
+            else:
+                current_text += line + "\n"
+        markdown_tups.append((current_header, current_text))
+
+        if current_header is not None:
+            # pass linting, assert keys are defined
+            markdown_tups = [
+                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
+                for key, value in markdown_tups
+            ]
+        else:
+            markdown_tups = [
+                (key, re.sub("\n", "", value)) for key, value in markdown_tups
+            ]
+
+        return markdown_tups
+
+    def remove_images(self, content: str) -> str:
+        """Get a dictionary of a markdown file from its path."""
+        pattern = r"!{1}\[\[(.*)\]\]"
+        content = re.sub(pattern, "", content)
+        return content
+
+    def remove_hyperlinks(self, content: str) -> str:
+        """Get a dictionary of a markdown file from its path."""
+        pattern = r"\[(.*?)\]\((.*?)\)"
+        content = re.sub(pattern, r"\1", content)
+        return content
+
+    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
+        """Parse file into tuples."""
+        content = ""
+        try:
+            with open(filepath, "r", encoding=self._encoding) as f:
+                content = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(filepath)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(filepath, encoding=encoding.encoding) as f:
+                            content = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {filepath}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {filepath}") from e
+
+        if self._remove_hyperlinks:
+            content = self.remove_hyperlinks(content)
+
+        if self._remove_images:
+            content = self.remove_images(content)
+
+        return self.markdown_to_tups(content)
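
The loader splits on heading lines and emits one Document per section, stripping `#` marks, hyperlinks and image embeds along the way. A sketch with a hypothetical file; the resulting page contents are approximate:

```python
md = """# Fruits
Apples are [tasty](https://example.com).

## Prices
Apples cost 3.
"""

with open("/tmp/notes.md", "w", encoding="utf-8") as f:
    f.write(md)

from core.data_loader.loader.markdown import MarkdownLoader

docs = MarkdownLoader("/tmp/notes.md").load()
# -> two documents, roughly:
#   "\n\nFruits\nApples are tasty."
#   "\n\nPrices\nApples cost 3."
```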

+ 236 - 236
api/core/data_source/notion.py → api/core/data_loader/loader/notion.py

@@ -1,68 +1,162 @@
-"""Notion reader."""
 import json
 import json
 import logging
 import logging
-import os
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import List, Dict, Any, Optional
 
 
-import requests  # type: ignore
+import requests
+from flask import current_app
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
 
 
-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
+from extensions.ext_database import db
+from models.dataset import Document as DocumentModel
+from models.source import DataSourceBinding
+
+logger = logging.getLogger(__name__)
 
 
-INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
 SEARCH_URL = "https://api.notion.com/v1/search"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-logger = logging.getLogger(__name__)
-
 
 
-# TODO: Notion DB reader coming soon!
-class NotionPageReader(BaseReader):
-    """Notion Page reader.
 
 
-    Reads a set of Notion pages.
-
-    Args:
-        integration_token (str): Notion integration token.
-
-    """
-
-    def __init__(self, integration_token: Optional[str] = None) -> None:
-        """Initialize with parameters."""
-        if integration_token is None:
-            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
+class NotionLoader(BaseLoader):
+    def __init__(
+            self,
+            notion_access_token: str,
+            notion_workspace_id: str,
+            notion_obj_id: str,
+            notion_page_type: str,
+            document_model: Optional[DocumentModel] = None
+    ):
+        self._document_model = document_model
+        self._notion_workspace_id = notion_workspace_id
+        self._notion_obj_id = notion_obj_id
+        self._notion_page_type = notion_page_type
+        self._notion_access_token = notion_access_token
+
+        if not self._notion_access_token:
+            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
             if integration_token is None:
             if integration_token is None:
                 raise ValueError(
                 raise ValueError(
                     "Must specify `integration_token` or set environment "
                     "Must specify `integration_token` or set environment "
                     "variable `NOTION_INTEGRATION_TOKEN`."
                     "variable `NOTION_INTEGRATION_TOKEN`."
                 )
                 )
-        self.token = integration_token
-        self.headers = {
-            "Authorization": "Bearer " + self.token,
-            "Content-Type": "application/json",
-            "Notion-Version": "2022-06-28",
-        }
 
 
-    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
-        """Read a block."""
-        done = False
+            self._notion_access_token = integration_token
+
+    @classmethod
+    def from_document(cls, document_model: DocumentModel):
+        data_source_info = document_model.data_source_info_dict
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+
+        notion_workspace_id = data_source_info['notion_workspace_id']
+        notion_obj_id = data_source_info['notion_page_id']
+        notion_page_type = data_source_info['type']
+        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
+
+        return cls(
+            notion_access_token=notion_access_token,
+            notion_workspace_id=notion_workspace_id,
+            notion_obj_id=notion_obj_id,
+            notion_page_type=notion_page_type,
+            document_model=document_model
+        )
+
+    def load(self) -> List[Document]:
+        self.update_last_edited_time(
+            self._document_model
+        )
+
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+
+        return text_docs
+
+    def _load_data_as_documents(
+            self, notion_obj_id: str, notion_page_type: str
+    ) -> List[Document]:
+        docs = []
+        if notion_page_type == 'database':
+            # get all the pages in the database
+            page_text = self._get_notion_database_data(notion_obj_id)
+            docs.append(Document(page_content=page_text))
+        elif notion_page_type == 'page':
+            page_text_list = self._get_notion_block_data(notion_obj_id)
+            for page_text in page_text_list:
+                docs.append(Document(page_content=page_text))
+        else:
+            raise ValueError("notion page type not supported")
+
+        return docs
+
+    def _get_notion_database_data(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> str:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict,
+        )
+
+        data = res.json()
+
+        database_content_list = []
+        if 'results' not in data or data["results"] is None:
+            return ""
+        for result in data["results"]:
+            properties = result['properties']
+            data = {}
+            for property_name, property_value in properties.items():
+                type = property_value['type']
+                if type == 'multi_select':
+                    value = []
+                    multi_select_list = property_value[type]
+                    for multi_select in multi_select_list:
+                        value.append(multi_select['name'])
+                elif type == 'rich_text' or type == 'title':
+                    if len(property_value[type]) > 0:
+                        value = property_value[type][0]['plain_text']
+                    else:
+                        value = ''
+                elif type == 'select' or type == 'status':
+                    if property_value[type]:
+                        value = property_value[type]['name']
+                    else:
+                        value = ''
+                else:
+                    value = property_value[type]
+                data[property_name] = value
+            database_content_list.append(json.dumps(data, ensure_ascii=False))
+
+        return "\n\n".join(database_content_list)
+
+    def _get_notion_block_data(self, page_id: str) -> List[str]:
         result_lines_arr = []
         result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
+        cur_block_id = page_id
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}
             query_dict: Dict[str, Any] = {}
 
 
             res = requests.request(
             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             )
             data = res.json()
             data = res.json()
-            if 'results' not in data or data["results"] is None:
-                done = True
-                break
+            # current block's heading
             heading = ''
             heading = ''
             for result in data["results"]:
             for result in data["results"]:
                 result_type = result["type"]
                 result_type = result["type"]
@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                 if result_type == 'table':
                     result_block_id = result["id"]
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
                     text = self._read_table_rows(result_block_id)
+                    text += "\n\n"
                     result_lines_arr.append(text)
                     result_lines_arr.append(text)
                 else:
                 else:
                     if "rich_text" in result_obj:
                     if "rich_text" in result_obj:
@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader):
                             # skip if doesn't have text object
                             # skip if doesn't have text object
                             if "text" in rich_text:
                             if "text" in rich_text:
                                 text = rich_text["text"]["content"]
                                 text = rich_text["text"]["content"]
-                                prefix = "\t" * num_tabs
-                                cur_result_text_arr.append(prefix + text)
+                                cur_result_text_arr.append(text)
                                 if result_type in HEADING_TYPE:
                                 if result_type in HEADING_TYPE:
                                     heading = text
                                     heading = text
+
                     result_block_id = result["id"]
                     result_block_id = result["id"]
                     has_children = result["has_children"]
                     has_children = result["has_children"]
                     block_type = result["type"]
                     block_type = result["type"]
                     if has_children and block_type != 'child_page':
                     if has_children and block_type != 'child_page':
                         children_text = self._read_block(
                         children_text = self._read_block(
-                            result_block_id, num_tabs=num_tabs + 1
+                            result_block_id, num_tabs=1
                         )
                         )
                         cur_result_text_arr.append(children_text)
                         cur_result_text_arr.append(children_text)
 
 
                     cur_result_text = "\n".join(cur_result_text_arr)
                     cur_result_text = "\n".join(cur_result_text_arr)
+                    cur_result_text += "\n\n"
                     if result_type in HEADING_TYPE:
                     if result_type in HEADING_TYPE:
                         result_lines_arr.append(cur_result_text)
                         result_lines_arr.append(cur_result_text)
                     else:
                     else:
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
 
 
             if data["next_cursor"] is None:
             if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                cur_block_id = data["next_cursor"]
-
-        result_lines = "\n".join(result_lines_arr)
-        return result_lines
-
-    def _read_table_rows(self, block_id: str) -> str:
-        """Read table rows."""
-        done = False
-        result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
-            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
-
-            res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
-            )
-            data = res.json()
-            # get table headers text
-            table_header_cell_texts = []
-            tabel_header_cells = data["results"][0]['table_row']['cells']
-            for tabel_header_cell in tabel_header_cells:
-                if tabel_header_cell:
-                    for table_header_cell_text in tabel_header_cell:
-                        text = table_header_cell_text["text"]["content"]
-                        table_header_cell_texts.append(text)
-            # get table columns text and format
-            results = data["results"]
-            for i in range(len(results)-1):
-                column_texts = []
-                tabel_column_cells = data["results"][i+1]['table_row']['cells']
-                for j in range(len(tabel_column_cells)):
-                    if tabel_column_cells[j]:
-                        for table_column_cell_text in tabel_column_cells[j]:
-                            column_text = table_column_cell_text["text"]["content"]
-                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
-
-                cur_result_text = "\n".join(column_texts)
-                result_lines_arr.append(cur_result_text)
-
-            if data["next_cursor"] is None:
-                done = True
                 break
                 break
             else:
             else:
                 cur_block_id = data["next_cursor"]
                 cur_block_id = data["next_cursor"]
+        return result_lines_arr
 
 
-        result_lines = "\n".join(result_lines_arr)
-        return result_lines
-    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
         """Read a block."""
         """Read a block."""
-        done = False
         result_lines_arr = []
         result_lines_arr = []
         cur_block_id = block_id
         cur_block_id = block_id
-        while not done:
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}
             query_dict: Dict[str, Any] = {}
 
 
             res = requests.request(
             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             )
             data = res.json()
             data = res.json()
-            # current block's heading
+            if 'results' not in data or data["results"] is None:
+                break
             heading = ''
             heading = ''
             for result in data["results"]:
             for result in data["results"]:
                 result_type = result["type"]
                 result_type = result["type"]
@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                 if result_type == 'table':
                     result_block_id = result["id"]
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
                     text = self._read_table_rows(result_block_id)
-                    text += "\n\n"
                     result_lines_arr.append(text)
                     result_lines_arr.append(text)
                 else:
                 else:
                     if "rich_text" in result_obj:
                     if "rich_text" in result_obj:
@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader):
                             # skip if doesn't have text object
                             # skip if doesn't have text object
                             if "text" in rich_text:
                             if "text" in rich_text:
                                 text = rich_text["text"]["content"]
                                 text = rich_text["text"]["content"]
-                                cur_result_text_arr.append(text)
+                                prefix = "\t" * num_tabs
+                                cur_result_text_arr.append(prefix + text)
                                 if result_type in HEADING_TYPE:
                                 if result_type in HEADING_TYPE:
                                     heading = text
                                     heading = text
-
                     result_block_id = result["id"]
                     result_block_id = result["id"]
                     has_children = result["has_children"]
                     has_children = result["has_children"]
                     block_type = result["type"]
                     block_type = result["type"]
@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader):
                         cur_result_text_arr.append(children_text)
                         cur_result_text_arr.append(children_text)
 
 
                     cur_result_text = "\n".join(cur_result_text_arr)
                     cur_result_text = "\n".join(cur_result_text_arr)
-                    cur_result_text += "\n\n"
                     if result_type in HEADING_TYPE:
                     if result_type in HEADING_TYPE:
                         result_lines_arr.append(cur_result_text)
                         result_lines_arr.append(cur_result_text)
                     else:
                     else:
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
 
 
             if data["next_cursor"] is None:
             if data["next_cursor"] is None:
-                done = True
                 break
                 break
             else:
             else:
                 cur_block_id = data["next_cursor"]
                 cur_block_id = data["next_cursor"]
-        return result_lines_arr
-
-    def read_page(self, page_id: str) -> str:
-        """Read a page."""
-        return self._read_block(page_id)
-
-    def read_page_as_documents(self, page_id: str) -> List[str]:
-        """Read a page as documents."""
-        return self._read_parent_blocks(page_id)
-
-    def query_database_data(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> str:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-                (
-            DATABASE_URL_TMPL.format(database_id=database_id),
-            headers=self.headers,
-            json=query_dict,
-        )
-        data = res.json()
-        database_content_list = []
-        if 'results' not in data or data["results"] is None:
-            return ""
-        for result in data["results"]:
-            properties = result['properties']
-            data = {}
-            for property_name, property_value in properties.items():
-                type = property_value['type']
-                if type == 'multi_select':
-                    value = []
-                    multi_select_list = property_value[type]
-                    for multi_select in multi_select_list:
-                        value.append(multi_select['name'])
-                elif type == 'rich_text' or type == 'title':
-                    if len(property_value[type]) > 0:
-                        value = property_value[type][0]['plain_text']
-                    else:
-                        value = ''
-                elif type == 'select' or type == 'status':
-                    if property_value[type]:
-                        value = property_value[type]['name']
-                    else:
-                        value = ''
-                else:
-                    value = property_value[type]
-                data[property_name] = value
-            database_content_list.append(json.dumps(data, ensure_ascii=False))
-
-        return "\n\n".join(database_content_list)
-
-    def query_database(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> List[str]:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-                (
-            DATABASE_URL_TMPL.format(database_id=database_id),
-            headers=self.headers,
-            json=query_dict,
-        )
-        data = res.json()
-        page_ids = []
-        for result in data["results"]:
-            page_id = result["id"]
-            page_ids.append(page_id)
 
 
-        return page_ids
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
 
 
-    def search(self, query: str) -> List[str]:
-        """Search Notion page given a text query."""
+    def _read_table_rows(self, block_id: str) -> str:
+        """Read table rows."""
         done = False
         done = False
-        next_cursor: Optional[str] = None
-        page_ids = []
+        result_lines_arr = []
+        cur_block_id = block_id
         while not done:
         while not done:
-            query_dict = {
-                "query": query,
-            }
-            if next_cursor is not None:
-                query_dict["start_cursor"] = next_cursor
-            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
+            )
             data = res.json()
             data = res.json()
-            for result in data["results"]:
-                page_id = result["id"]
-                page_ids.append(page_id)
+            # get table headers text
+            table_header_cell_texts = []
+            tabel_header_cells = data["results"][0]['table_row']['cells']
+            for tabel_header_cell in tabel_header_cells:
+                if tabel_header_cell:
+                    for table_header_cell_text in tabel_header_cell:
+                        text = table_header_cell_text["text"]["content"]
+                        table_header_cell_texts.append(text)
+            # get table columns text and format
+            results = data["results"]
+            for i in range(len(results) - 1):
+                column_texts = []
+                tabel_column_cells = data["results"][i + 1]['table_row']['cells']
+                for j in range(len(tabel_column_cells)):
+                    if tabel_column_cells[j]:
+                        for table_column_cell_text in tabel_column_cells[j]:
+                            column_text = table_column_cell_text["text"]["content"]
+                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
+
+                cur_result_text = "\n".join(column_texts)
+                result_lines_arr.append(cur_result_text)
 
 
             if data["next_cursor"] is None:
             if data["next_cursor"] is None:
                 done = True
                 done = True
                 break
                 break
             else:
             else:
-                next_cursor = data["next_cursor"]
-        return page_ids
+                cur_block_id = data["next_cursor"]
 
 
-    def load_data(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        """Load data from the input directory.
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
 
 
-        Args:
-            page_ids (List[str]): List of page ids to load.
+    def update_last_edited_time(self, document_model: DocumentModel):
+        if not document_model:
+            return
 
 
-        Returns:
-            List[Document]: List of documents.
+        last_edited_time = self.get_notion_last_edited_time()
+        data_source_info = document_model.data_source_info_dict
+        data_source_info['last_edited_time'] = last_edited_time
+        update_params = {
+            DocumentModel.data_source_info: json.dumps(data_source_info)
+        }
 
 
-        """
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_ids = self.query_database(database_id)
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
+        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.commit()
 
 
-        return docs
-
-    def load_data_as_documents(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_text = self.query_database_data(database_id)
-            docs.append(Document(page_text))
+    def get_notion_last_edited_time(self) -> str:
+        obj_id = self._notion_obj_id
+        page_type = self._notion_page_type
+        if page_type == 'database':
+            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
         else:
         else:
-            for page_id in page_ids:
-                page_text_list = self.read_page_as_documents(page_id)
-                for page_text in page_text_list:
-                    docs.append(Document(page_text))
+            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
 
 
-        return docs
-
-    def get_page_last_edited_time(self, page_id: str) -> str:
-        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
         query_dict: Dict[str, Any] = {}
         query_dict: Dict[str, Any] = {}
 
 
         res = requests.request(
         res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
+            "GET",
+            retrieve_page_url,
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict
        )
-        data = res.json()
-        return data["last_edited_time"]
-
-    def get_database_last_edited_time(self, database_id: str) -> str:
-        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
-        query_dict: Dict[str, Any] = {}
 
 
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
        data = res.json()
        return data["last_edited_time"]
 
 
+    @classmethod
+    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+            )
+        ).first()
+
+        if not data_source_binding:
+            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
+                            f'and notion workspace {notion_workspace_id}')
 
 
-if __name__ == "__main__":
-    reader = NotionPageReader()
-    logger.info(reader.search("What I"))
+        return data_source_binding.access_token
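
A minimal re-sync sketch built on the helpers above; `loader` stands for an instance of this Notion loader and `document_model` for the matching Document row, both assumed to be in hand:

    stored_time = document_model.data_source_info_dict.get('last_edited_time')

    if loader.get_notion_last_edited_time() != stored_time:
        # the page or database changed on Notion's side: re-import it,
        # then persist the new timestamp on the Document row
        loader.update_last_edited_time(document_model)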

+ 55 - 0
api/core/data_loader/loader/pdf.py

@@ -0,0 +1,55 @@
+import logging
+from typing import List, Optional
+
+from langchain.document_loaders import PyPDFium2Loader
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)
+
+
+class PdfLoader(BaseLoader):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        upload_file: Optional[UploadFile] = None
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._upload_file = upload_file
+
+    def load(self) -> List[Document]:
+        plaintext_file_key = ''
+        plaintext_file_exists = False
+        if self._upload_file:
+            if self._upload_file.hash:
+                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+                                     + self._upload_file.hash + '.0625.plaintext'
+                try:
+                    text = storage.load(plaintext_file_key).decode('utf-8')
+                    plaintext_file_exists = True
+                    return [Document(page_content=text)]
+                except FileNotFoundError:
+                    pass
+        documents = PyPDFium2Loader(file_path=self._file_path).load()
+        text_list = []
+        for document in documents:
+            text_list.append(document.page_content)
+        text = "\n\n".join(text_list)
+
+        # save plaintext file for caching
+        if not plaintext_file_exists and plaintext_file_key:
+            storage.save(plaintext_file_key, text.encode('utf-8'))
+
+        return documents
+
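
A usage sketch for the loader above, assuming `upload_file` is an UploadFile row with tenant_id and hash set so the extracted plaintext can be cached; without it the PDF is simply parsed on every call:

    from core.data_loader.loader.pdf import PdfLoader

    loader = PdfLoader(file_path='/tmp/example.pdf', upload_file=upload_file)
    documents = loader.load()  # list of langchain Documents; cached plaintext is reused on later runs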

+ 35 - 45
api/core/docstore/dataset_docstore.py

@@ -1,10 +1,6 @@
from typing import Any, Dict, Optional, Sequence

-import tiktoken
-from llama_index.data_structs import Node
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.docstore.utils import json_to_doc
-from llama_index.schema import BaseDocument
+from langchain.schema import Document
from sqlalchemy import func

from core.llm.token_calculator import TokenCalculator
@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment


-class DatesetDocumentStore(BaseDocumentStore):
+class DatesetDocumentStore:
    def __init__(
        self,
        dataset: Dataset,
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
        return self._embedding_model_name

    @property
-    def docs(self) -> Dict[str, BaseDocument]:
+    def docs(self) -> Dict[str, Document]:
        document_segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id
        ).all()
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
        output = {}
        for document_segment in document_segments:
            doc_id = document_segment.index_node_id
-            result = self.segment_to_dict(document_segment)
-            output[doc_id] = json_to_doc(result)
+            output[doc_id] = Document(
+                page_content=document_segment.content,
+                metadata={
+                    "doc_id": document_segment.index_node_id,
+                    "doc_hash": document_segment.index_node_hash,
+                    "document_id": document_segment.document_id,
+                    "dataset_id": document_segment.dataset_id,
+                }
+            )

        return output

    def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
+        self, docs: Sequence[Document], allow_update: bool = True
    ) -> None:
        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
            DocumentSegment.document == self._document_id
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
            max_position = 0

        for doc in docs:
-            if doc.is_doc_id_none:
-                raise ValueError("doc_id not set")
+            if not isinstance(doc, Document):
+                raise ValueError("doc must be a Document")

-            if not isinstance(doc, Node):
-                raise ValueError("doc must be a Node")
-
-            segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
+            segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)

            # NOTE: doc could already exist in the store, but we overwrite it
            if not allow_update and segment_document:
                raise ValueError(
-                    f"doc_id {doc.get_doc_id()} already exists. "
+                    f"doc_id {doc.metadata['doc_id']} already exists. "
                    "Set allow_update to True to overwrite."
                )

            # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
+            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)

            if not segment_document:
                max_position += 1
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
                    tenant_id=self._dataset.tenant_id,
                    dataset_id=self._dataset.id,
                    document_id=self._document_id,
-                    index_node_id=doc.get_doc_id(),
-                    index_node_hash=doc.get_doc_hash(),
+                    index_node_id=doc.metadata['doc_id'],
+                    index_node_hash=doc.metadata['doc_hash'],
                    position=max_position,
-                    content=doc.get_text(),
-                    word_count=len(doc.get_text()),
+                    content=doc.page_content,
+                    word_count=len(doc.page_content),
                    tokens=tokens,
                    created_by=self._user_id,
                )
                db.session.add(segment_document)
            else:
-                segment_document.content = doc.get_text()
-                segment_document.index_node_hash = doc.get_doc_hash()
-                segment_document.word_count = len(doc.get_text())
+                segment_document.content = doc.page_content
+                segment_document.index_node_hash = doc.metadata['doc_hash']
+                segment_document.word_count = len(doc.page_content)
                segment_document.tokens = tokens

            db.session.commit()
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
 
 
    def get_document(
        self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
+    ) -> Optional[Document]:
        document_segment = self.get_document_segment(doc_id)

        if document_segment is None:
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
            else:
                return None

-        result = self.segment_to_dict(document_segment)
-        return json_to_doc(result)
+        return Document(
+            page_content=document_segment.content,
+            metadata={
+                "doc_id": document_segment.index_node_id,
+                "doc_hash": document_segment.index_node_hash,
+                "document_id": document_segment.document_id,
+                "dataset_id": document_segment.dataset_id,
+            }
+        )

    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
        document_segment = self.get_document_segment(doc_id)
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
 
 
        return document_segment.index_node_hash

-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))
-
    def get_document_segment(self, doc_id: str) -> DocumentSegment:
        document_segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == self._dataset.id,
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
        ).first()

        return document_segment
-
-    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
-        return {
-            "doc_id": segment.index_node_id,
-            "doc_hash": segment.index_node_hash,
-            "text": segment.content,
-            "__type__": Node.get_type()
-        }
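
A sketch of feeding the reworked store, assuming a DatesetDocumentStore instance (`docstore`) already constructed for the target dataset and document; chunks now travel as langchain Documents whose metadata carries the ids the store reads in place of the old llama_index node accessors:

    from langchain.schema import Document

    docstore.add_documents([
        Document(
            page_content="chunk text ...",
            metadata={"doc_id": "node-1", "doc_hash": "sha256:..."},  # hypothetical values
        )
    ])

    doc = docstore.get_document("node-1")  # returned as a langchain Document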

+ 0 - 51
api/core/docstore/empty_docstore.py

@@ -1,51 +0,0 @@
-from typing import Any, Dict, Optional, Sequence
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.schema import BaseDocument
-
-
-class EmptyDocumentStore(BaseDocumentStore):
-    @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
-        return cls()
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Serialize to dict."""
-        return {}
-
-    @property
-    def docs(self) -> Dict[str, BaseDocument]:
-        return {}
-
-    def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
-    ) -> None:
-        pass
-
-    def document_exists(self, doc_id: str) -> bool:
-        """Check if document exists."""
-        return False
-
-    def get_document(
-        self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
-        return None
-
-    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
-        pass
-
-    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
-        """Set the hash for a given doc_id."""
-        pass
-
-    def get_document_hash(self, doc_id: str) -> Optional[str]:
-        """Get the stored hash for a document, if it exists."""
-        return None
-
-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))

+ 72 - 0
api/core/embedding/cached_embedding.py

@@ -0,0 +1,72 @@
+import logging
+from typing import List
+
+from langchain.embeddings.base import Embeddings
+from sqlalchemy.exc import IntegrityError
+
+from extensions.ext_database import db
+from libs import helper
+from models.dataset import Embedding
+
+
+class CacheEmbedding(Embeddings):
+    def __init__(self, embeddings: Embeddings):
+        self._embeddings = embeddings
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed search docs."""
+        # use doc embedding cache or store if not exists
+        text_embeddings = []
+        embedding_queue_texts = []
+        for text in texts:
+            hash = helper.generate_text_hash(text)
+            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+            if embedding:
+                text_embeddings.append(embedding.get_embedding())
+            else:
+                embedding_queue_texts.append(text)
+
+        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
+
+        for i, text in enumerate(embedding_queue_texts):
+            hash = helper.generate_text_hash(text)
+
+            try:
+                embedding = Embedding(hash=hash)
+                embedding.set_embedding(embedding_results[i])
+                db.session.add(embedding)
+                db.session.commit()
+            except IntegrityError:
+                db.session.rollback()
+                continue
+            except Exception:
+                logging.exception('Failed to add embedding to db')
+                continue
+
+        text_embeddings.extend(embedding_results)
+        return text_embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query text."""
+        # use doc embedding cache or store if not exists
+        hash = helper.generate_text_hash(text)
+        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+        if embedding:
+            return embedding.get_embedding()
+
+        embedding_results = self._embeddings.embed_query(text)
+
+        try:
+            embedding = Embedding(hash=hash)
+            embedding.set_embedding(embedding_results)
+            db.session.add(embedding)
+            db.session.commit()
+        except IntegrityError:
+            db.session.rollback()
+        except Exception:
+            logging.exception('Failed to add embedding to db')
+
+        return embedding_results
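
A wiring sketch for the cache wrapper; OpenAIEmbeddings and the key are stand-ins for whatever langchain Embeddings implementation the caller supplies, and a database session is assumed to be available:

    from langchain.embeddings import OpenAIEmbeddings

    from core.embedding.cached_embedding import CacheEmbedding

    embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))  # placeholder credentials

    vectors = embeddings.embed_documents(["hello world", "goodbye world"])
    vector = embeddings.embed_query("hello world")  # later calls for the same text are served from the Embedding table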

+ 0 - 214
api/core/embedding/openai_embedding.py

@@ -1,214 +0,0 @@
-from typing import Optional, Any, List
-
-import openai
-from llama_index.embeddings.base import BaseEmbedding
-from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
-    _TEXT_MODE_MODEL_DICT
-from tenacity import wait_random_exponential, retry, stop_after_attempt
-
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(
-        text: str,
-        engine: Optional[str] = None,
-        api_key: Optional[str] = None,
-        **kwargs
-) -> List[float]:
-    """Get embedding.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
-    float]:
-    """Asynchronously get embedding.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    # replace newlines, which can negatively affect performance.
-    text = text.replace("\n", " ")
-
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
-        "embedding"
-    ]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embeddings(
-        list_of_text: List[str],
-        engine: Optional[str] = None,
-        api_key: Optional[str] = None,
-        **kwargs
-) -> List[List[float]]:
-    """Get embeddings.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
-    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
-    return [d["embedding"] for d in data]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embeddings(
-        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
-) -> List[List[float]]:
-    """Asynchronously get embeddings.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
-    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
-    return [d["embedding"] for d in data]
-
-
-class OpenAIEmbedding(BaseEmbedding):
-
-    def __init__(
-            self,
-            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-            deployment_name: Optional[str] = None,
-            openai_api_key: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        new_kwargs = {}
-
-        if 'embed_batch_size' in kwargs:
-            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
-
-        if 'tokenizer' in kwargs:
-            new_kwargs['tokenizer'] = kwargs['tokenizer']
-
-        super().__init__(**new_kwargs)
-        self.mode = OpenAIEmbeddingMode(mode)
-        self.model = OpenAIEmbeddingModelType(model)
-        self.deployment_name = deployment_name
-        self.openai_api_key = openai_api_key
-        self.openai_api_type = kwargs.get('openai_api_type')
-        self.openai_api_version = kwargs.get('openai_api_version')
-        self.openai_api_base = kwargs.get('openai_api_base')
-
-    @handle_llm_exceptions
-    def _get_query_embedding(self, query: str) -> List[float]:
-        """Get query embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _QUERY_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
-                             api_type=self.openai_api_type, api_version=self.openai_api_version,
-                             api_base=self.openai_api_base)
-
-    def _get_text_embedding(self, text: str) -> List[float]:
-        """Get text embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
-                             api_type=self.openai_api_type, api_version=self.openai_api_version,
-                             api_base=self.openai_api_base)
-
-    async def _aget_text_embedding(self, text: str) -> List[float]:
-        """Asynchronously get text embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
-                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                    api_base=self.openai_api_base)
-
-    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Get text embeddings.
-
-        By default, this is a wrapper around _get_text_embedding.
-        Can be overriden for batch queries.
-
-        """
-        if self.openai_api_type and self.openai_api_type == 'azure':
-            embeddings = []
-            for text in texts:
-                embeddings.append(self._get_text_embedding(text))
-
-            return embeddings
-
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
-                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                    api_base=self.openai_api_base)
-        return embeddings
-
-    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Asynchronously get text embeddings."""
-        if self.openai_api_type and self.openai_api_type == 'azure':
-            embeddings = []
-            for text in texts:
-                embeddings.append(await self._aget_text_embedding(text))
-
-            return embeddings
-
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
-                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                           api_base=self.openai_api_base)
-        return embeddings

+ 59 - 0
api/core/index/base.py

@@ -0,0 +1,59 @@
+from __future__ import annotations
+from abc import abstractmethod, ABC
+from typing import List, Any
+
+from langchain.schema import Document, BaseRetriever
+
+from models.dataset import Dataset
+
+
+class BaseIndex(ABC):
+
+    def __init__(self, dataset: Dataset):
+        self.dataset = dataset
+
+    @abstractmethod
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_texts(self, texts: list[Document], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def text_exists(self, id: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_ids(self, ids: list[str]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_document_id(self, document_id: str):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        raise NotImplementedError
+
+    def delete(self) -> None:
+        raise NotImplementedError
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts[:]:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def _get_uuids(self, texts: list[Document]) -> list[str]:
+        return [text.metadata['doc_id'] for text in texts]
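
A hypothetical subclass sketch (reusing the imports above) showing where the two helpers at the bottom are meant to slot in; the in-memory dict is purely illustrative and the remaining abstract methods are elided:

    class InMemoryIndex(BaseIndex):
        """Toy backend that keeps Documents in a dict keyed by doc_id."""

        def __init__(self, dataset: Dataset):
            super().__init__(dataset)
            self._docs: dict[str, Document] = {}

        def add_texts(self, texts: list[Document], **kwargs):
            texts = self._filter_duplicate_texts(texts)  # drop chunks whose doc_id is already indexed
            for doc_id, doc in zip(self._get_uuids(texts), texts):
                self._docs[doc_id] = doc

        def text_exists(self, id: str) -> bool:
            return id in self._docs

        # create / delete_by_ids / delete_by_document_id / get_retriever / search elided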

+ 41 - 0
api/core/index/index.py

@@ -0,0 +1,41 @@
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+
+
+class IndexBuilder:
+    @classmethod
+    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
+        if indexing_technique == "high_quality":
+            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
+                return None
+
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+
+            return VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+        elif indexing_technique == "economy":
+            return KeywordTableIndex(
+                dataset=dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=10
+                )
+            )
+        else:
+            raise ValueError('Unknown indexing technique')
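
A sketch of calling the builder, assuming a Dataset row and an active Flask application context (the high-quality branch reads current_app.config and the tenant's embedding credentials):

    from core.index.index import IndexBuilder

    index = IndexBuilder.get_index(dataset, 'high_quality')
    if index is None:
        # dataset was not built as a high-quality (vector) index; fall back to the keyword table
        index = IndexBuilder.get_index(dataset, 'economy')

    index.add_texts(documents)  # `documents`: list of langchain Documents carrying doc_id metadata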

+ 0 - 60
api/core/index/index_builder.py

@@ -1,60 +0,0 @@
-from langchain.callbacks import CallbackManager
-from llama_index import ServiceContext, PromptHelper, LLMPredictor
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.embedding.openai_embedding import OpenAIEmbedding
-from core.llm.llm_builder import LLMBuilder
-
-
-class IndexBuilder:
-    @classmethod
-    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
-        # set number of output tokens
-        num_output = 512
-
-        # only for verbose
-        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
-
-        llm = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name='text-davinci-003',
-            temperature=0,
-            max_tokens=num_output,
-            callback_manager=callback_manager,
-        )
-
-        llm_predictor = LLMPredictor(llm=llm)
-
-        # These parameters here will affect the logic of segmenting the final synthesized response.
-        # The number of refinement iterations in the synthesis process depends
-        # on whether the length of the segmented output exceeds the max_input_size.
-        prompt_helper = PromptHelper(
-            max_input_size=3500,
-            num_output=num_output,
-            max_chunk_overlap=20
-        )
-
-        provider = LLMBuilder.get_default_provider(tenant_id)
-
-        model_credentials = LLMBuilder.get_model_credentials(
-            tenant_id=tenant_id,
-            model_provider=provider,
-            model_name='text-embedding-ada-002'
-        )
-
-        return ServiceContext.from_defaults(
-            llm_predictor=llm_predictor,
-            prompt_helper=prompt_helper,
-            embed_model=OpenAIEmbedding(**model_credentials),
-        )
-
-    @classmethod
-    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
-        llm = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name='fake'
-        )
-
-        return ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )

+ 0 - 159
api/core/index/keyword_table/jieba_keyword_table.py

@@ -1,159 +0,0 @@
-import re
-from typing import (
-    Any,
-    Dict,
-    List,
-    Set,
-    Optional
-)
-
-import jieba.analyse
-
-from core.index.keyword_table.stopwords import STOPWORDS
-from llama_index.indices.query.base import IS
-from llama_index import QueryMode
-from llama_index.indices.base import QueryMap
-from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
-from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
-from llama_index.docstore import BaseDocumentStore
-from llama_index.indices.postprocessor.node import (
-    BaseNodePostprocessor,
-)
-from llama_index.indices.response.response_builder import ResponseMode
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-
-from core.index.query.synthesizer import EnhanceResponseSynthesizer
-
-
-def jieba_extract_keywords(
-        text_chunk: str,
-        max_keywords: Optional[int] = None,
-        expand_with_subtokens: bool = True,
-) -> Set[str]:
-    """Extract keywords with JIEBA tfidf."""
-    keywords = jieba.analyse.extract_tags(
-        sentence=text_chunk,
-        topK=max_keywords,
-    )
-
-    if expand_with_subtokens:
-        return set(expand_tokens_with_subtokens(keywords))
-    else:
-        return set(keywords)
-
-
-def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
-    """Get subtokens from a list of tokens., filtering for stopwords."""
-    results = set()
-    for token in tokens:
-        results.add(token)
-        sub_tokens = re.findall(r"\w+", token)
-        if len(sub_tokens) > 1:
-            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
-
-    return results
-
-
-class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
-    """GPT JIEBA Keyword Table Index.
-
-    This index uses a JIEBA keyword extractor to extract keywords from the text.
-
-    """
-
-    def _extract_keywords(self, text: str) -> Set[str]:
-        """Extract keywords from text."""
-        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
-
-    @classmethod
-    def get_query_map(self) -> QueryMap:
-        """Get query map."""
-        super_map = super().get_query_map()
-        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
-        return super_map
-
-    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
-        """Delete a document."""
-        # get set of ids that correspond to node
-        node_idxs_to_delete = {doc_id}
-
-        # delete node_idxs from keyword to node idxs mapping
-        keywords_to_delete = set()
-        for keyword, node_idxs in self._index_struct.table.items():
-            if node_idxs_to_delete.intersection(node_idxs):
-                self._index_struct.table[keyword] = node_idxs.difference(
-                    node_idxs_to_delete
-                )
-                if not self._index_struct.table[keyword]:
-                    keywords_to_delete.add(keyword)
-
-        for keyword in keywords_to_delete:
-            del self._index_struct.table[keyword]
-
-
-class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
-    """GPT Keyword Table Index JIEBA Query.
-
-    Extracts keywords using JIEBA keyword extractor.
-    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
-
-    .. code-block:: python
-
-        response = index.query("<query_str>", mode="jieba")
-
-    See BaseGPTKeywordTableQuery for arguments.
-
-    """
-
-    @classmethod
-    def from_args(
-            cls,
-            index_struct: IS,
-            service_context: ServiceContext,
-            docstore: Optional[BaseDocumentStore] = None,
-            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
-            verbose: bool = False,
-            # response synthesizer args
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_kwargs: Optional[Dict] = None,
-            use_async: bool = False,
-            streaming: bool = False,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-            # class-specific args
-            **kwargs: Any,
-    ) -> "BaseGPTIndexQuery":
-        response_synthesizer = EnhanceResponseSynthesizer.from_args(
-            service_context=service_context,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            simple_template=simple_template,
-            response_mode=response_mode,
-            response_kwargs=response_kwargs,
-            use_async=use_async,
-            streaming=streaming,
-            optimizer=optimizer,
-        )
-        return cls(
-            index_struct=index_struct,
-            service_context=service_context,
-            response_synthesizer=response_synthesizer,
-            docstore=docstore,
-            node_postprocessors=node_postprocessors,
-            verbose=verbose,
-            **kwargs,
-        )
-
-    def _get_keywords(self, query_str: str) -> List[str]:
-        """Extract keywords."""
-        return list(
-            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
-        )

+ 0 - 135
api/core/index/keyword_table_index.py

@@ -1,135 +0,0 @@
-import json
-from typing import List, Optional
-
-from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
-from llama_index.data_structs import KeywordTable, Node
-from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
-from llama_index.indices.registry import load_index_struct_from_dict
-
-from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.index_builder import IndexBuilder
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.llm.llm_builder import LLMBuilder
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
-
-
-class KeywordTableIndex:
-
-    def __init__(self, dataset: Dataset):
-        self._dataset = dataset
-
-    def add_nodes(self, nodes: List[Node]):
-        llm = LLMBuilder.to_llm(
-            tenant_id=self._dataset.tenant_id,
-            model_name='fake'
-        )
-
-        service_context = ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            index_struct = KeywordTable()
-        else:
-            index_struct_dict = dataset_keyword_table.keyword_table_dict
-            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
-
-        # create index
-        index = GPTJIEBAKeywordTableIndex(
-            index_struct=index_struct,
-            docstore=EmptyDocumentStore(),
-            service_context=service_context
-        )
-
-        for node in nodes:
-            keywords = index._extract_keywords(node.get_text())
-            self.update_segment_keywords(node.doc_id, list(keywords))
-            index._index_struct.add_node(list(keywords), node)
-
-        index_struct_dict = index.index_struct.to_dict()
-
-        if not dataset_keyword_table:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
-                keyword_table=json.dumps(index_struct_dict)
-            )
-            db.session.add(dataset_keyword_table)
-        else:
-            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
-
-        db.session.commit()
-
-    def del_nodes(self, node_ids: List[str]):
-        llm = LLMBuilder.to_llm(
-            tenant_id=self._dataset.tenant_id,
-            model_name='fake'
-        )
-
-        service_context = ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            return
-        else:
-            index_struct_dict = dataset_keyword_table.keyword_table_dict
-            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
-
-        # create index
-        index = GPTJIEBAKeywordTableIndex(
-            index_struct=index_struct,
-            docstore=EmptyDocumentStore(),
-            service_context=service_context
-        )
-
-        for node_id in node_ids:
-            index.delete(node_id)
-
-        index_struct_dict = index.index_struct.to_dict()
-
-        if not dataset_keyword_table:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
-                keyword_table=json.dumps(index_struct_dict)
-            )
-            db.session.add(dataset_keyword_table)
-        else:
-            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
-
-        db.session.commit()
-
-    @property
-    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
-        docstore = DatesetDocumentStore(
-            dataset=self._dataset,
-            user_id=self._dataset.created_by,
-            embedding_model_name="text-embedding-ada-002"
-        )
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            return None
-
-        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
-
-        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
-
-    def get_keyword_table(self):
-        dataset_keyword_table = self._dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            return dataset_keyword_table
-        return None
-
-    def update_segment_keywords(self, node_id: str, keywords: List[str]):
-        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
-        if document_segment:
-            document_segment.keywords = keywords
-            db.session.commit()

+ 33 - 0
api/core/index/keyword_table_index/jieba_keyword_table_handler.py

@@ -0,0 +1,33 @@
+import re
+from typing import Set
+
+import jieba
+from jieba.analyse import default_tfidf
+
+from core.index.keyword_table_index.stopwords import STOPWORDS
+
+
+class JiebaKeywordTableHandler:
+
+    def __init__(self):
+        default_tfidf.stop_words = STOPWORDS
+
+    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
+        """Extract keywords with JIEBA tfidf."""
+        keywords = jieba.analyse.extract_tags(
+            sentence=text,
+            topK=max_keywords_per_chunk,
+        )
+
+        return set(self._expand_tokens_with_subtokens(keywords))
+
+    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
+        """Get subtokens from a list of tokens, filtering for stopwords."""
+        results = set()
+        for token in tokens:
+            results.add(token)
+            sub_tokens = re.findall(r"\w+", token)
+            if len(sub_tokens) > 1:
+                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
+
+        return results
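
The handler can be exercised on its own; the sample sentence is arbitrary:

    from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

    handler = JiebaKeywordTableHandler()
    keywords = handler.extract_keywords("Dify 是一个开源的 LLM 应用开发平台", max_keywords_per_chunk=10)
    # a set of jieba TF-IDF keywords, expanded with their non-stopword sub-tokens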

+ 238 - 0
api/core/index/keyword_table_index/keyword_table_index.py

@@ -0,0 +1,238 @@
+import json
+from collections import defaultdict
+from typing import Any, List, Optional, Dict
+
+from langchain.schema import Document, BaseRetriever
+from pydantic import BaseModel, Field, Extra
+
+from core.index.base import BaseIndex
+from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
+
+
+class KeywordTableConfig(BaseModel):
+    max_keywords_per_chunk: int = 10
+
+
+class KeywordTableIndex(BaseIndex):
+    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+        super().__init__(dataset)
+        self._config = config
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = {}
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        dataset_keyword_table = DatasetKeywordTable(
+            dataset_id=self.dataset.id,
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self.dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
+        )
+        db.session.add(dataset_keyword_table)
+        db.session.commit()
+
+        self._save_dataset_keyword_table(keyword_table)
+
+        return self
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        keyword_table_handler = JiebaKeywordTableHandler()
+
+        keyword_table = self._get_dataset_keyword_table()
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def text_exists(self, id: str) -> bool:
+        keyword_table = self._get_dataset_keyword_table()
+        return id in set.union(set(), *keyword_table.values())
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        keyword_table = self._get_dataset_keyword_table()
+        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def delete_by_document_id(self, document_id: str):
+        # get segment ids by document_id
+        segments = db.session.query(DocumentSegment).filter(
+            DocumentSegment.dataset_id == self.dataset.id,
+            DocumentSegment.document_id == document_id
+        ).all()
+
+        ids = [segment.index_node_id for segment in segments]
+
+        keyword_table = self._get_dataset_keyword_table()
+        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        return KeywordTableRetriever(index=self, **kwargs)
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        keyword_table = self._get_dataset_keyword_table()
+
+        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
+        k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+
+        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
+
+        documents = []
+        for chunk_index in sorted_chunk_indices:
+            segment = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.index_node_id == chunk_index
+            ).first()
+
+            if segment:
+                documents.append(Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": chunk_index,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                ))
+
+        return documents
+
+    def delete(self) -> None:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            db.session.delete(dataset_keyword_table)
+            db.session.commit()
+
+    def _save_dataset_keyword_table(self, keyword_table):
+        keyword_table_dict = {
+            '__type__': 'keyword_table',
+            '__data__': {
+                "index_id": self.dataset.id,
+                "summary": None,
+                "table": keyword_table
+            }
+        }
+        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+        db.session.commit()
+
+    def _get_dataset_keyword_table(self) -> Optional[dict]:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            if dataset_keyword_table.keyword_table_dict:
+                return dataset_keyword_table.keyword_table_dict['__data__']['table']
+        else:
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table=json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            )
+            db.session.add(dataset_keyword_table)
+            db.session.commit()
+
+        return {}
+
+    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
+        for keyword in keywords:
+            if keyword not in keyword_table:
+                keyword_table[keyword] = set()
+            keyword_table[keyword].add(id)
+        return keyword_table
+
+    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
+        # get set of ids that correspond to node
+        node_idxs_to_delete = set(ids)
+
+        # delete node_idxs from keyword to node idxs mapping
+        keywords_to_delete = set()
+        for keyword, node_idxs in keyword_table.items():
+            if node_idxs_to_delete.intersection(node_idxs):
+                keyword_table[keyword] = node_idxs.difference(
+                    node_idxs_to_delete
+                )
+                if not keyword_table[keyword]:
+                    keywords_to_delete.add(keyword)
+
+        for keyword in keywords_to_delete:
+            del keyword_table[keyword]
+
+        return keyword_table
+
+    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keywords = keyword_table_handler.extract_keywords(query)
+
+        # go through text chunks in order of most matching keywords
+        chunk_indices_count: Dict[str, int] = defaultdict(int)
+        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+        for keyword in keywords:
+            for node_id in keyword_table[keyword]:
+                chunk_indices_count[node_id] += 1
+
+        sorted_chunk_indices = sorted(
+            list(chunk_indices_count.keys()),
+            key=lambda x: chunk_indices_count[x],
+            reverse=True,
+        )
+
+        return sorted_chunk_indices[: k]
+
+    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
+        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
+        if document_segment:
+            document_segment.keywords = keywords
+            db.session.commit()
+
+
+class KeywordTableRetriever(BaseRetriever, BaseModel):
+    index: KeywordTableIndex
+    search_kwargs: dict = Field(default_factory=dict)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
+        return self.index.search(query, **self.search_kwargs)
+
+    async def aget_relevant_documents(self, query: str) -> List[Document]:
+        raise NotImplementedError("KeywordTableRetriever does not support async")
+
+
+class SetEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, set):
+            return list(obj)
+        return super().default(obj)
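
An end-to-end sketch of the economy index, assuming a Dataset row, a database session, and chunks whose doc_id metadata matches existing DocumentSegment.index_node_id values (search results are resolved back through those segments):

    from langchain.schema import Document

    from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig

    index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))
    index.create([
        Document(page_content="Dify 支持基于数据集的问答", metadata={"doc_id": "node-1"}),
    ])

    docs = index.search("数据集 问答", search_kwargs={"k": 4})
    # or via the langchain retriever interface
    docs = index.get_retriever(search_kwargs={"k": 4}).get_relevant_documents("数据集 问答")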

+ 0 - 0
api/core/index/keyword_table/stopwords.py → api/core/index/keyword_table_index/stopwords.py


+ 0 - 79
api/core/index/query/synthesizer.py

@@ -1,79 +0,0 @@
-from typing import (
-    Any,
-    Dict,
-    Optional, Sequence,
-)
-
-from llama_index.indices.response.response_synthesis import ResponseSynthesizer
-from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-from llama_index.types import RESPONSE_TEXT_TYPE
-
-
-class EnhanceResponseSynthesizer(ResponseSynthesizer):
-    @classmethod
-    def from_args(
-            cls,
-            service_context: ServiceContext,
-            streaming: bool = False,
-            use_async: bool = False,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            response_kwargs: Optional[Dict] = None,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-    ) -> "ResponseSynthesizer":
-        response_builder: Optional[BaseResponseBuilder] = None
-        if response_mode != ResponseMode.NO_TEXT:
-            if response_mode == 'no_synthesizer':
-                response_builder = NoSynthesizer(
-                    service_context=service_context,
-                    simple_template=simple_template,
-                    streaming=streaming,
-                )
-            else:
-                response_builder = get_response_builder(
-                    service_context,
-                    text_qa_template,
-                    refine_template,
-                    simple_template,
-                    response_mode,
-                    use_async=use_async,
-                    streaming=streaming,
-                )
-        return cls(response_builder, response_mode, response_kwargs, optimizer)
-
-
-class NoSynthesizer(BaseResponseBuilder):
-    def __init__(
-            self,
-            service_context: ServiceContext,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            streaming: bool = False,
-    ) -> None:
-        super().__init__(service_context, streaming)
-
-    async def aget_response(
-            self,
-            query_str: str,
-            text_chunks: Sequence[str],
-            prev_response: Optional[str] = None,
-            **response_kwargs: Any,
-    ) -> RESPONSE_TEXT_TYPE:
-        return "\n".join(text_chunks)
-
-    def get_response(
-            self,
-            query_str: str,
-            text_chunks: Sequence[str],
-            prev_response: Optional[str] = None,
-            **response_kwargs: Any,
-    ) -> RESPONSE_TEXT_TYPE:
-        return "\n".join(text_chunks)

+ 0 - 22
api/core/index/readers/html_parser.py

@@ -1,22 +0,0 @@
-from pathlib import Path
-from typing import Dict
-
-from bs4 import BeautifulSoup
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class HTMLParser(BaseParser):
-    """HTML parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser."""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        """Parse file."""
-        with open(file, "rb") as fp:
-            soup = BeautifulSoup(fp, 'html.parser')
-            text = soup.get_text()
-            text = text.strip() if text else ''
-
-        return text

+ 0 - 111
api/core/index/readers/markdown_parser.py

@@ -1,111 +0,0 @@
-"""Markdown parser.
-
-Contains parser for md files.
-
-"""
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
-
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class MarkdownParser(BaseParser):
-    """Markdown parser.
-
-    Extract text from markdown files.
-    Returns dictionary with keys as headers and values as the text between headers.
-
-    """
-
-    def __init__(
-        self,
-        *args: Any,
-        remove_hyperlinks: bool = True,
-        remove_images: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        super().__init__(*args, **kwargs)
-        self._remove_hyperlinks = remove_hyperlinks
-        self._remove_images = remove_images
-
-    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
-        """Convert a markdown file to a dictionary.
-
-        The keys are the headers and the values are the text under each header.
-
-        """
-        markdown_tups: List[Tuple[Optional[str], str]] = []
-        lines = markdown_text.split("\n")
-
-        current_header = None
-        current_text = ""
-
-        for line in lines:
-            header_match = re.match(r"^#+\s", line)
-            if header_match:
-                if current_header is not None:
-                    markdown_tups.append((current_header, current_text))
-
-                current_header = line
-                current_text = ""
-            else:
-                current_text += line + "\n"
-        markdown_tups.append((current_header, current_text))
-
-        if current_header is not None:
-            # pass linting, assert keys are defined
-            markdown_tups = [
-                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
-                for key, value in markdown_tups
-            ]
-        else:
-            markdown_tups = [
-                (key, re.sub("\n", "", value)) for key, value in markdown_tups
-            ]
-
-        return markdown_tups
-
-    def remove_images(self, content: str) -> str:
-        """Get a dictionary of a markdown file from its path."""
-        pattern = r"!{1}\[\[(.*)\]\]"
-        content = re.sub(pattern, "", content)
-        return content
-
-    def remove_hyperlinks(self, content: str) -> str:
-        """Get a dictionary of a markdown file from its path."""
-        pattern = r"\[(.*?)\]\((.*?)\)"
-        content = re.sub(pattern, r"\1", content)
-        return content
-
-    def _init_parser(self) -> Dict:
-        """Initialize the parser with the config."""
-        return {}
-
-    def parse_tups(
-        self, filepath: Path, errors: str = "ignore"
-    ) -> List[Tuple[Optional[str], str]]:
-        """Parse file into tuples."""
-        with open(filepath, "r", encoding="utf-8") as f:
-            content = f.read()
-        if self._remove_hyperlinks:
-            content = self.remove_hyperlinks(content)
-        if self._remove_images:
-            content = self.remove_images(content)
-        markdown_tups = self.markdown_to_tups(content)
-        return markdown_tups
-
-    def parse_file(
-        self, filepath: Path, errors: str = "ignore"
-    ) -> Union[str, List[str]]:
-        """Parse file into string."""
-        tups = self.parse_tups(filepath, errors=errors)
-        results = []
-        # TODO: don't include headers right now
-        for header, value in tups:
-            if header is None:
-                results.append(value)
-            else:
-                results.append(f"\n\n{header}\n{value}")
-        return results
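The heart of this parser is markdown_to_tups, which walks the file line by line and pairs each `#` header with the text beneath it. A standalone sketch of that splitting step, without the llama_index BaseParser plumbing or the hyperlink/image stripping (the function name is illustrative):

import re
from typing import List, Optional, Tuple

def markdown_to_tups(markdown_text: str) -> List[Tuple[Optional[str], str]]:
    # Pair each "# ..." header line with the text that follows it; text before
    # the first header is paired with a None key.
    tups: List[Tuple[Optional[str], str]] = []
    current_header, current_text = None, ""
    for line in markdown_text.split("\n"):
        if re.match(r"^#+\s", line):
            if current_header is not None:
                tups.append((current_header, current_text))
            current_header, current_text = line, ""
        else:
            current_text += line + "\n"
    tups.append((current_header, current_text))
    return tups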

+ 0 - 56
api/core/index/readers/pdf_parser.py

@@ -1,56 +0,0 @@
-from pathlib import Path
-from typing import Dict
-
-from flask import current_app
-from llama_index.readers.file.base_parser import BaseParser
-from pypdf import PdfReader
-
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-
-class PDFParser(BaseParser):
-    """PDF parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser."""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        """Parse file."""
-        if not current_app.config.get('PDF_PREVIEW', True):
-            return ''
-
-        plaintext_file_key = ''
-        plaintext_file_exists = False
-        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
-            upload_file: UploadFile = self._parser_config['upload_file']
-            if upload_file.hash:
-                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
-                try:
-                    text = storage.load(plaintext_file_key).decode('utf-8')
-                    plaintext_file_exists = True
-                    return text
-                except FileNotFoundError:
-                    pass
-
-        text_list = []
-        with open(file, "rb") as fp:
-            # Create a PDF object
-            pdf = PdfReader(fp)
-
-            # Get the number of pages in the PDF document
-            num_pages = len(pdf.pages)
-
-            # Iterate over every page
-            for page in range(num_pages):
-                # Extract the text from the page
-                page_text = pdf.pages[page].extract_text()
-                text_list.append(page_text)
-        text = "\n".join(text_list)
-
-        # save plaintext file for caching
-        if not plaintext_file_exists and plaintext_file_key:
-            storage.save(plaintext_file_key, text.encode('utf-8'))
-
-        return text
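Besides extraction, the parser cached the extracted plaintext back into storage keyed by the upload file's hash. The extraction step itself reduces to a few lines of pypdf; a minimal sketch without the cache lookup (function name is illustrative):

from pypdf import PdfReader

def extract_pdf_text(filepath: str) -> str:
    # Concatenate the text extracted from every page; extract_text() may return None.
    reader = PdfReader(filepath)
    return "\n".join(page.extract_text() or "" for page in reader.pages)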

+ 0 - 33
api/core/index/readers/xlsx_parser.py

@@ -1,33 +0,0 @@
-from pathlib import Path
-import json
-from typing import Dict
-from openpyxl import load_workbook
-
-from llama_index.readers.file.base_parser import BaseParser
-from flask import current_app
-
-
-class XLSXParser(BaseParser):
-    """XLSX parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser"""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        data = []
-        keys = []
-        with open(file, "r") as fp:
-            wb = load_workbook(filename=file, read_only=True)
-            # loop over all sheets
-            for sheet in wb:
-                for row in sheet.iter_rows(values_only=True):
-                    if all(v is None for v in row):
-                        continue
-                    if keys == []:
-                        keys = list(map(str, row))
-                    else:
-                        row_dict = dict(zip(keys, row))
-                        row_dict = {k: v for k, v in row_dict.items() if v}
-                        data.append(json.dumps(row_dict, ensure_ascii=False))
-        return '\n\n'.join(data)
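The parser treats the first non-empty row it encounters as the header row and serialises every later row as one JSON object per line. A standalone sketch of the same conversion (note the `open()` call in the original is redundant, since load_workbook reads the path directly):

import json
from openpyxl import load_workbook

def xlsx_to_text(filepath: str) -> str:
    # The first non-empty row becomes the header row; each later row becomes one
    # JSON object, and rows are joined with blank lines for later splitting.
    data, keys = [], []
    wb = load_workbook(filename=filepath, read_only=True)
    for sheet in wb:
        for row in sheet.iter_rows(values_only=True):
            if all(v is None for v in row):
                continue
            if not keys:
                keys = list(map(str, row))
            else:
                row_dict = {k: v for k, v in zip(keys, row) if v}
                data.append(json.dumps(row_dict, ensure_ascii=False))
    return '\n\n'.join(data)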

+ 0 - 136
api/core/index/vector_index.py

@@ -1,136 +0,0 @@
-import json
-import logging
-from typing import List, Optional
-
-from llama_index.data_structs import Node
-from requests import ReadTimeout
-from sqlalchemy.exc import IntegrityError
-from tenacity import retry, stop_after_attempt, retry_if_exception_type
-
-from core.index.index_builder import IndexBuilder
-from core.vector_store.base import BaseGPTVectorStoreIndex
-from extensions.ext_vector_store import vector_store
-from extensions.ext_database import db
-from models.dataset import Dataset, Embedding
-
-
-class VectorIndex:
-
-    def __init__(self, dataset: Dataset):
-        self._dataset = dataset
-
-    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
-        if not self._dataset.index_struct_dict:
-            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
-            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
-            db.session.commit()
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        if duplicate_check:
-            nodes = self._filter_duplicate_nodes(index, nodes)
-
-        embedding_queue_nodes = []
-        embedded_nodes = []
-        for node in nodes:
-            node_hash = node.doc_hash
-
-            # if node hash in cached embedding tables, use cached embedding
-            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
-            if embedding:
-                node.embedding = embedding.get_embedding()
-                embedded_nodes.append(node)
-            else:
-                embedding_queue_nodes.append(node)
-
-        if embedding_queue_nodes:
-            embedding_results = index._get_node_embedding_results(
-                embedding_queue_nodes,
-                set(),
-            )
-
-            # pre embed nodes for cached embedding
-            for embedding_result in embedding_results:
-                node = embedding_result.node
-                node.embedding = embedding_result.embedding
-
-                try:
-                    embedding = Embedding(hash=node.doc_hash)
-                    embedding.set_embedding(node.embedding)
-                    db.session.add(embedding)
-                    db.session.commit()
-                except IntegrityError:
-                    db.session.rollback()
-                    continue
-                except:
-                    logging.exception('Failed to add embedding to db')
-                    continue
-
-                embedded_nodes.append(node)
-
-        self.index_insert_nodes(index, embedded_nodes)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
-        index.insert_nodes(nodes)
-
-    def del_nodes(self, node_ids: List[str]):
-        if not self._dataset.index_struct_dict:
-            return
-
-        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        for node_id in node_ids:
-            self.index_delete_node(index, node_id)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
-        index.delete_node(node_id)
-
-    def del_doc(self, doc_id: str):
-        if not self._dataset.index_struct_dict:
-            return
-
-        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        self.index_delete_doc(index, doc_id)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
-        index.delete(doc_id)
-
-    @property
-    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
-        if not self._dataset.index_struct_dict:
-            return None
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        return vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
-        for node in nodes:
-            node_id = node.doc_id
-            exists_duplicate_node = index.exists_by_node_id(node_id)
-            if exists_duplicate_node:
-                nodes.remove(node)
-
-        return nodes
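The insert/delete wrappers above exist only to retry transient read timeouts from the vector store. The tenacity pattern they use is worth keeping in mind when reading the replacement classes below; a minimal sketch, where `index` is any object exposing insert_nodes:

from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt

@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def insert_nodes_with_retry(index, nodes):
    # Retry up to three times on read timeouts, then re-raise the last error.
    index.insert_nodes(nodes)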

+ 175 - 0
api/core/index/vector_index/base.py

@@ -0,0 +1,175 @@
+import json
+import logging
+from abc import abstractmethod
+from typing import List, Any, cast
+
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from weaviate import UnexpectedStatusCodeException
+
+from core.index.base import BaseIndex
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+
+class BaseVectorIndex(BaseIndex):
+    
+    def __init__(self, dataset: Dataset, embeddings: Embeddings):
+        super().__init__(dataset)
+        self._embeddings = embeddings
+        self._vector_store = None
+        
+    def get_type(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_index_name(self, dataset: Dataset) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def to_index_struct(self) -> dict:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_vector_store(self) -> VectorStore:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_vector_store_class(self) -> type:
+        raise NotImplementedError
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
+        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
+
+        if search_type == 'similarity_score_threshold':
+            score_threshold = search_kwargs.get("score_threshold")
+            if (score_threshold is None) or (not isinstance(score_threshold, float)):
+                search_kwargs['score_threshold'] = .0
+
+            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
+                query, **search_kwargs
+            )
+
+            docs = []
+            for doc, similarity in docs_with_similarity:
+                doc.metadata['score'] = similarity
+                docs.append(doc)
+
+            return docs
+
+        # similarity k
+        # mmr k, fetch_k, lambda_mult
+        # similarity_score_threshold k
+        return vector_store.as_retriever(
+            search_type=search_type,
+            search_kwargs=search_kwargs
+        ).get_relevant_documents(query)
+
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        return vector_store.as_retriever(**kwargs)
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        if kwargs.get('duplicate_check', False):
+            texts = self._filter_duplicate_texts(texts)
+
+        uuids = self._get_uuids(texts)
+        vector_store.add_documents(texts, uuids=uuids)
+
+    def text_exists(self, id: str) -> bool:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        return vector_store.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        for node_id in ids:
+            vector_store.del_text(node_id)
+
+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
+    def _is_origin(self):
+        return False
+
+    def recreate_dataset(self, dataset: Dataset):
+        logging.info(f"Recreating dataset {dataset.id}")
+
+        try:
+            self.delete()
+        except UnexpectedStatusCodeException as e:
+            if e.status_code != 400:
+                # 400 means index not exists
+                raise e
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+            
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        origin_index_struct = self.dataset.index_struct
+        self.dataset.index_struct = None
+
+        if documents:
+            try:
+                self.create(documents)
+            except Exception as e:
+                self.dataset.index_struct = origin_index_struct
+                raise e
+
+            dataset.index_struct = json.dumps(self.to_index_struct())
+
+        db.session.commit()
+
+        self.dataset = dataset
+        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 116 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -0,0 +1,116 @@
+import os
+from typing import Optional, Any, List, cast
+
+import qdrant_client
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from pydantic import BaseModel
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.qdrant_vector_store import QdrantVectorStore
+from models.dataset import Dataset
+
+
+class QdrantConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    root_path: Optional[str]
+    
+    def to_qdrant_params(self):
+        if self.endpoint and self.endpoint.startswith('path:'):
+            path = self.endpoint.replace('path:', '')
+            if not os.path.isabs(path):
+                path = os.path.join(self.root_path, path)
+
+            return {
+                'path': path
+            }
+        else:
+            return {
+                'url': self.endpoint,
+                'api_key': self.api_key,
+            }
+
+
+class QdrantVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client_config = config
+
+    def get_type(self) -> str:
+        return 'qdrant'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            return self.dataset.index_struct_dict['vector_store']['collection_name']
+
+        dataset_id = dataset.id
+        return "Index_" + dataset_id.replace("-", "_")
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = QdrantVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            collection_name=self.get_index_name(self.dataset),
+            ids=uuids,
+            content_payload_key='text',
+            **self._client_config.to_qdrant_params()
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+        
+        client = qdrant_client.QdrantClient(
+            **self._client_config.to_qdrant_params()
+        )
+
+        return QdrantVectorStore(
+            client=client,
+            collection_name=self.get_index_name(self.dataset),
+            embeddings=self._embeddings,
+            content_payload_key='text'
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return QdrantVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="metadata.document_id",
+                    match=models.MatchValue(value=document_id),
+                ),
+            ],
+        ))
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
+            if class_prefix.startswith('Vector_'):
+                # original class_prefix
+                return True
+
+        return False
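to_qdrant_params() lets a single QDRANT_URL setting cover both an embedded, file-backed Qdrant and a remote server. Illustrative values only:

QdrantConfig(endpoint="path:storage/qdrant", api_key=None, root_path="/app/api").to_qdrant_params()
# -> {'path': '/app/api/storage/qdrant'}   (local, file-backed client)

QdrantConfig(endpoint="https://qdrant.example.com", api_key="secret", root_path=None).to_qdrant_params()
# -> {'url': 'https://qdrant.example.com', 'api_key': 'secret'}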

+ 69 - 0
api/core/index/vector_index/vector_index.py

@@ -0,0 +1,69 @@
+import json
+
+from flask import current_app
+from langchain.embeddings.base import Embeddings
+
+from core.index.vector_index.base import BaseVectorIndex
+from extensions.ext_database import db
+from models.dataset import Dataset, Document
+
+
+class VectorIndex:
+    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
+        self._dataset = dataset
+        self._embeddings = embeddings
+        self._vector_index = self._init_vector_index(dataset, config, embeddings)
+
+    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
+        vector_type = config.get('VECTOR_STORE')
+
+        if self._dataset.index_struct_dict:
+            vector_type = self._dataset.index_struct_dict['type']
+
+        if not vector_type:
+            raise ValueError(f"Vector store must be specified.")
+
+        if vector_type == "weaviate":
+            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
+
+            return WeaviateVectorIndex(
+                dataset=dataset,
+                config=WeaviateConfig(
+                    endpoint=config.get('WEAVIATE_ENDPOINT'),
+                    api_key=config.get('WEAVIATE_API_KEY'),
+                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
+                ),
+                embeddings=embeddings
+            )
+        elif vector_type == "qdrant":
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+            return QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=config.get('QDRANT_URL'),
+                    api_key=config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        if not self._dataset.index_struct_dict:
+            self._vector_index.create(texts, **kwargs)
+            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
+            db.session.commit()
+            return
+
+        self._vector_index.add_texts(texts, **kwargs)
+
+    def __getattr__(self, name):
+        if self._vector_index is not None:
+            method = getattr(self._vector_index, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
+
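A hedged usage sketch of the factory, assuming `dataset` is an existing models.dataset.Dataset row, `embeddings` is any langchain Embeddings implementation (indexing_runner.py imports CacheEmbedding for this), and `documents` is a list of langchain Documents; the config is simply the Flask app config that carries VECTOR_STORE and the store-specific keys:

from flask import current_app
from core.index.vector_index.vector_index import VectorIndex

index = VectorIndex(
    dataset=dataset,
    config=current_app.config,
    embeddings=embeddings,
)
index.add_texts(documents, duplicate_check=True)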

+ 132 - 0
api/core/index/vector_index/weaviate_vector_index.py

@@ -0,0 +1,132 @@
+from typing import Optional, cast
+
+import weaviate
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from pydantic import BaseModel, root_validator
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
+from models.dataset import Dataset
+
+
+class WeaviateConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config WEAVIATE_ENDPOINT is required")
+        return values
+
+
+class WeaviateVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client = self._init_client(config)
+
+    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
+        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+
+        weaviate.connect.connection.has_grpc = False
+
+        client = weaviate.Client(
+            url=config.endpoint,
+            auth_client_secret=auth_config,
+            timeout_config=(5, 60),
+            startup_period=None
+        )
+
+        client.batch.configure(
+            # `batch_size` takes an `int` value to enable auto-batching
+            # (`None` is used for manual batching)
+            batch_size=config.batch_size,
+            # dynamically update the `batch_size` based on import speed
+            dynamic=True,
+            # `timeout_retries` takes an `int` value to retry on time outs
+            timeout_retries=3,
+        )
+
+        return client
+
+    def get_type(self) -> str:
+        return 'weaviate'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
+
+        return WeaviateVectorStore(
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            text_key='text',
+            embedding=self._embeddings,
+            attributes=attributes,
+            by_text=False
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return WeaviateVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": ["document_id"],
+            "valueText": document_id
+        })
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                return True
+
+        return False
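The class name encodes whether a dataset was indexed before or after this refactor: new datasets get a "Vector_index_<dataset id with dashes replaced>_Node" class, while pre-existing index_structs keep their old class_prefix, which is what _is_origin() checks before recreate_dataset() is triggered. A hedged config example with illustrative values:

config = WeaviateConfig(endpoint="http://localhost:8080", api_key=None, batch_size=100)
# an empty WEAVIATE_ENDPOINT fails the root_validator with a ValueError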

+ 262 - 283
api/core/indexing_runner.py

@@ -1,35 +1,34 @@
 import datetime
 import json
+import logging
 import re
-import tempfile
 import time
-from pathlib import Path
-from typing import Optional, List
+import uuid
+from typing import Optional, List, cast
 
+from flask import current_app
 from flask_login import current_user
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
-from llama_index import SimpleDirectoryReader
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.node_parser import SimpleNodeParser, NodeParser
-from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
-from llama_index.readers.file.markdown_parser import MarkdownParser
-
-from core.data_source.notion import NotionPageReader
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
+from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.markdown_parser import MarkdownParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.llm_builder import LLMBuilder
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
+from libs import helper
+from models.dataset import Document as DatasetDocument
+from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
 from models.source import DataSourceBinding
 
@@ -40,135 +39,171 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name
 
-    def run(self, documents: List[Document]):
+    def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
-        for document in documents:
+        for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # load file
+                text_docs = self._load_data(dataset_document)
+
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
-                id=document.dataset_id
+                id=dataset_document.dataset_id
             ).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
 
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            for document_segment in document_segments:
+                db.session.delete(document_segment)
+            db.session.commit()
+
             # load file
-            text_docs = self._load_data(document)
+            text_docs = self._load_data(dataset_document)
 
             # get the process rule
             processing_rule = db.session.query(DatasetProcessRule). \
-                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._step_split(
+            # split to documents
+            documents = self._step_split(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 dataset=dataset,
-                document=document,
+                dataset_document=dataset_document,
                 processing_rule=processing_rule
             )
 
             # build index
             self._build_index(
                 dataset=dataset,
-                document=document,
-                nodes=nodes
+                dataset_document=dataset_document,
+                documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
 
-    def run_in_splitting_status(self, document: Document):
-        """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-        # load file
-        text_docs = self._load_data(document)
-
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is indexing."""
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()
 
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            if not dataset:
+                raise ValueError("no dataset found")
 
 
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
 
-    def run_in_indexing_status(self, document: Document):
-        """Run the indexing process when the index_status is indexing."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        nodes = []
-        if document_segments:
-            for document_segment in document_segments:
-                # transform segment to node
-                if document_segment.status != "completed":
-                    relationships = {
-                        DocumentRelationship.SOURCE: document_segment.document_id,
-                    }
-
-                    previous_segment = document_segment.previous_segment
-                    if previous_segment:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-                    next_segment = document_segment.next_segment
-                    if next_segment:
-                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-                    node = Node(
-                        doc_id=document_segment.index_node_id,
-                        doc_hash=document_segment.index_node_hash,
-                        text=document_segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-                    nodes.append(node)
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                dataset_document=dataset_document,
+                documents=documents
+            )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
 
     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
@@ -179,28 +214,28 @@ class IndexingRunner:
         total_segments = 0
         for file_detail in file_details:
             # load data from file
-            text_docs = self._load_data_from_file(file_detail)
+            text_docs = FileExtractor.load(file_detail)
 
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._split_to_nodes(
+            # split to documents
+            documents = self._split_to_documents(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 processing_rule=processing_rule
            )
-            total_segments += len(nodes)
-            for node in nodes:
+            total_segments += len(documents)
+            for document in documents:
                 if len(preview_texts) < 5:
-                    preview_texts.append(node.get_text())
+                    preview_texts.append(document.page_content)
 
-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -230,35 +265,36 @@ class IndexingRunner:
             ).first()
             if not data_source_binding:
                 raise ValueError('Data source binding not found.')
-            reader = NotionPageReader(integration_token=data_source_binding.access_token)
+
             for page in notion_info['pages']:
             for page in notion_info['pages']:
-                if page['type'] == 'page':
-                    page_ids = [page['page_id']]
-                    documents = reader.load_data_as_documents(page_ids=page_ids)
-                elif page['type'] == 'database':
-                    documents = reader.load_data_as_documents(database_id=page['page_id'])
-                else:
-                    documents = []
+                loader = NotionLoader(
+                    notion_access_token=data_source_binding.access_token,
+                    notion_workspace_id=workspace_id,
+                    notion_obj_id=page['page_id'],
+                    notion_page_type=page['type']
+                )
+                documents = loader.load()
+
                 processing_rule = DatasetProcessRule(
                     mode=tmp_processing_rule["mode"],
                     rules=json.dumps(tmp_processing_rule["rules"])
                 )
 
-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
 
-                # split to nodes
-                nodes = self._split_to_nodes(
+                # split to documents
+                documents = self._split_to_documents(
                     text_docs=documents,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     processing_rule=processing_rule
                 )
-                total_segments += len(nodes)
-                for node in nodes:
+                total_segments += len(documents)
+                for document in documents:
                     if len(preview_texts) < 5:
-                        preview_texts.append(node.get_text())
+                        preview_texts.append(document.page_content)
 
-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -268,14 +304,14 @@ class IndexingRunner:
             "preview": preview_texts
         }
 
-    def _load_data(self, document: Document) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
         # load file
-        if document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
 
-        data_source_info = document.data_source_info_dict
+        data_source_info = dataset_document.data_source_info_dict
         text_docs = []
-        if document.data_source_type == 'upload_file':
+        if dataset_document.data_source_type == 'upload_file':
             if not data_source_info or 'upload_file_id' not in data_source_info:
                 raise ValueError("no upload file found")
 
@@ -283,47 +319,28 @@ class IndexingRunner:
                 filter(UploadFile.id == data_source_info['upload_file_id']). \
                 one_or_none()
 
-            text_docs = self._load_data_from_file(file_detail)
-        elif document.data_source_type == 'notion_import':
-            if not data_source_info or 'notion_page_id' not in data_source_info \
-                    or 'notion_workspace_id' not in data_source_info:
-                raise ValueError("no notion page found")
-            workspace_id = data_source_info['notion_workspace_id']
-            page_id = data_source_info['notion_page_id']
-            page_type = data_source_info['type']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == document.tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-            if page_type == 'page':
-                # add page last_edited_time to data_source_info
-                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
-            elif page_type == 'database':
-                # add page last_edited_time to data_source_info
-                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
+            text_docs = FileExtractor.load(file_detail)
+        elif dataset_document.data_source_type == 'notion_import':
+            loader = NotionLoader.from_document(dataset_document)
+            text_docs = loader.load()
+
         # update document status to splitting
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
-                Document.parsing_completed_at: datetime.datetime.utcnow()
+                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )
 
         # replace doc id to document model id
+        text_docs = cast(List[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
-            text_doc.text = self.filter_string(text_doc.get_text())
-            text_doc.doc_id = document.id
+            text_doc.page_content = self.filter_string(text_doc.page_content)
+            text_doc.metadata['document_id'] = dataset_document.id
+            text_doc.metadata['dataset_id'] = dataset_document.dataset_id
 
         return text_docs
 
@@ -331,61 +348,7 @@ class IndexingRunner:
         pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
         return pattern.sub('', text)
 
-    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            self.storage.download(upload_file.key, filepath)
-
-            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
-            file_extractor[".markdown"] = MarkdownParser()
-            file_extractor[".md"] = MarkdownParser()
-            file_extractor[".html"] = HTMLParser()
-            file_extractor[".htm"] = HTMLParser()
-            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
-            file_extractor[".xlsx"] = XLSXParser()
-
-            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
-            text_docs = loader.load_data()
-
-            return text_docs
-
-    def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
-        page_ids = [page_id]
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(page_ids=page_ids)
-        return text_docs
-
-    def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(database_id=database_id)
-        return text_docs
-
-    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_database_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
+    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
         """
         """
         Get the NodeParser object according to the processing rule.
         Get the NodeParser object according to the processing rule.
         """
         """
@@ -414,68 +377,83 @@ class IndexingRunner:
                 separators=["\n\n", "。", ".", " ", ""]
                 separators=["\n\n", "。", ".", " ", ""]
             )
             )
 
 
-        return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
+        return character_splitter
 
 
-    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
-                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
+    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> List[Document]:
         """
-        Split the text documents into nodes and save them to the document segment.
+        Split the text documents into documents and save them to the document segment.
         """
-        nodes = self._split_to_nodes(
+        documents = self._split_to_documents(
            text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
            processing_rule=processing_rule
        )
 
        # save node to document segment
        doc_store = DatesetDocumentStore(
            dataset=dataset,
-            user_id=document.created_by,
+            user_id=dataset_document.created_by,
            embedding_model_name=self.embedding_model_name,
-            document_id=document.id
+            document_id=dataset_document.id
        )
+
        # add document segments
-        doc_store.add_documents(nodes)
+        doc_store.add_documents(documents)
 
         # update document status to indexing
         cur_time = datetime.datetime.utcnow()
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="indexing",
             extra_update_params={
-                Document.cleaning_completed_at: cur_time,
-                Document.splitting_completed_at: cur_time,
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
             }
         )
 
         # update segment status to indexing
         self._update_segments_by_document(
-            document_id=document.id,
+            dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
                 DocumentSegment.indexing_at: datetime.datetime.utcnow()
             }
         )
 
-        return nodes
+        return documents
 
-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                        processing_rule: DatasetProcessRule) -> List[Document]:
         """
         Split the text documents into nodes.
         """
-        all_nodes = []
+        all_documents = []
         for text_doc in text_docs:
             # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text
 
             # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
 
-        return all_nodes
+        return all_documents
 
     def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
         """
@@ -506,37 +484,38 @@ class IndexingRunner:
 
         return text
 
-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
         """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
 
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
             # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]
 
             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
             )
 
             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)
 
             # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)
 
-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                 DocumentSegment.status == "indexing"
             ).update({
                 DocumentSegment.status: "completed",
@@ -549,12 +528,12 @@ class IndexingRunner:
 
         # update document status to completed
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="completed",
             extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
             }
         )
 
@@ -569,25 +548,25 @@
         """
         Update the document indexing status.
         """
-        count = Document.query.filter_by(id=document_id, is_paused=True).count()
+        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedException()
 
         update_params = {
-            Document.indexing_status: after_indexing_status
+            DatasetDocument.indexing_status: after_indexing_status
         }
 
         if extra_update_params:
             update_params.update(extra_update_params)
 
-        Document.query.filter_by(id=document_id).update(update_params)
+        DatasetDocument.query.filter_by(id=document_id).update(update_params)
         db.session.commit()
 
-    def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
+    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
+        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
 

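For reference, a minimal sketch of how the new splitter-based pipeline is meant to be used: chunks come back as langchain Documents that carry doc_id/doc_hash in their metadata. The sha256 call is only a stand-in for the project's helper.generate_text_hash, and the chunk sizes and metadata values are illustrative:

    import hashlib
    import uuid

    from langchain.schema import Document
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Same separator list the runner uses; sizes are illustrative.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "。", ".", " ", ""],
    )

    source = Document(page_content="...raw text...", metadata={"document_id": "example"})

    chunks = []
    for chunk in splitter.split_documents([source]):
        if not chunk.page_content or not chunk.page_content.strip():
            continue  # drop empty chunks, as _split_to_documents does
        chunk.metadata["doc_id"] = str(uuid.uuid4())
        # stand-in for helper.generate_text_hash
        chunk.metadata["doc_hash"] = hashlib.sha256(chunk.page_content.encode("utf-8")).hexdigest()
        chunks.append(chunk)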
+ 17 - 14
api/core/llm/llm_builder.py

@@ -1,7 +1,6 @@
-from typing import Union, Optional
+from typing import Union, Optional, List
 
-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler
 
 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
     """
 
     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)
 
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
             if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
 
         return llm_cls(
             model_name=model_name,
             temperature=kwargs.get('temperature', 0),
             max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
             streaming=kwargs.get('streaming', False),
             # request_timeout=None
             **model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:
 
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         model_name = model.get("name")
         completion_params = model.get("completion_params", {})
 
@@ -82,7 +85,7 @@ class LLMBuilder:
             frequency_penalty=completion_params.get('frequency_penalty', 0.1),
             presence_penalty=completion_params.get('presence_penalty', 0.1),
             streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
         )
 
     @classmethod

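Call sites change accordingly: instead of wrapping handlers in a CallbackManager, they now pass a plain list of handlers. A hedged usage sketch (tenant id and model settings are illustrative values):

    from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
    from core.llm.llm_builder import LLMBuilder

    # top_p / frequency_penalty / presence_penalty keep being passed as plain kwargs;
    # for chat models they are routed through model_kwargs internally.
    llm = LLMBuilder.to_llm(
        tenant_id="tenant-id-placeholder",
        model_name="gpt-3.5-turbo",
        temperature=0,
        max_tokens=256,
        streaming=False,
        callbacks=[DifyStdOutCallbackHandler()],
    )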
+ 4 - 1
api/core/llm/provider/azure_provider.py

@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config
 
     def get_provider_name(self):

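This matters because langchain's Azure chat/completion wrappers read `deployment_name` while OpenAIEmbeddings reads `deployment`. The resulting credential dicts look roughly like this (placeholder values; the real dicts also carry the API key, base URL and API version):

    chat_config = {"openai_api_type": "azure", "deployment_name": "gpt-35-turbo"}
    embedding_config = {"openai_api_type": "azure", "deployment": "text-embedding-ada-002"}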
+ 13 - 50
api/core/llm/streamable_azure_chat_open_ai.py

@@ -1,3 +1,4 @@
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
 
         return message_tokens
 
-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)

+ 13 - 6
api/core/llm/streamable_azure_open_ai.py

@@ -1,5 +1,4 @@
-import os
-
+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
 
     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)

+ 14 - 48
api/core/llm/streamable_chat_open_ai.py

@@ -1,6 +1,7 @@
 import os
 
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
 
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
 
         return message_tokens
 
-    def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)

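The same pattern applies across all four streamable wrappers: per-request handlers travel with the generate()/agenerate() call instead of being bound to a class-level callback_manager. A hedged usage sketch (credentials are placeholders):

    from langchain.schema import HumanMessage

    from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
    from core.llm.streamable_chat_open_ai import StreamableChatOpenAI

    llm = StreamableChatOpenAI(
        model_name="gpt-3.5-turbo",
        openai_api_key="sk-placeholder",  # placeholder credential
        temperature=0,
    )

    # Callbacks are passed per call, matching the newer langchain signature.
    result = llm.generate(
        [[HumanMessage(content="Hello")]],
        callbacks=[DifyStdOutCallbackHandler()],
    )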
+ 13 - 5
api/core/llm/streamable_open_ai.py

@@ -1,5 +1,6 @@
 import os
 
+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-
     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)

+ 1 - 1
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py

@@ -1,7 +1,7 @@
 from typing import Any, List, Dict
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string
 
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory

+ 0 - 19
api/core/prompt/prompts.py

@@ -1,5 +1,3 @@
-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )
 
-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input. 
 You will be provided with the prompt, variables, and an opening statement. 

+ 0 - 0
api/core/index/spiltter/fixed_text_splitter.py → api/core/spiltter/fixed_text_splitter.py


+ 87 - 0
api/core/tool/dataset_index_tool.py

@@ -0,0 +1,87 @@
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.tools import BaseTool
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+
+
+class DatasetTool(BaseTool):
+    """Tool for querying a Dataset."""
+
+    dataset: Dataset
+    k: int = 2
+
+    def _run(self, tool_input: str) -> str:
+        if self.dataset.indexing_technique == "economy":
+            # use keyword table query
+            kw_table_index = KeywordTableIndex(
+                dataset=self.dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=5
+                )
+            )
+
+            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
+        else:
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=self.dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+
+            vector_index = VectorIndex(
+                dataset=self.dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+
+            documents = vector_index.search(
+                tool_input,
+                search_type='similarity',
+                search_kwargs={
+                    'k': self.k
+                }
+            )
+
+            hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
+            hit_callback.on_tool_end(documents)
+
+        return str("\n".join([document.page_content for document in documents]))
+
+    async def _arun(self, tool_input: str) -> str:
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=self.dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+            model_name='text-embedding-ada-002'
+        )
+
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=self.dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
+
+        documents = await vector_index.asearch(
+            tool_input,
+            search_type='similarity',
+            search_kwargs={
+                'k': 10
+            }
+        )
+
+        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
+        hit_callback.on_tool_end(documents)
+        return str("\n".join([document.page_content for document in documents]))

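A hedged sketch of using the new tool directly (the dataset lookup and query string are illustrative; name and description are required by langchain's BaseTool):

    from core.tool.dataset_index_tool import DatasetTool
    from models.dataset import Dataset

    dataset = Dataset.query.filter_by(id="dataset-uuid").first()  # illustrative lookup

    tool = DatasetTool(
        name=f"dataset-{dataset.id}",
        description=dataset.description or f"useful for when you want to answer queries about the {dataset.name}",
        dataset=dataset,
        k=2,
    )

    print(tool.run("What does the handbook say about refunds?"))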
+ 0 - 73
api/core/tool/dataset_tool_builder.py

@@ -1,73 +0,0 @@
-from typing import Optional
-
-from langchain.callbacks import CallbackManager
-from llama_index.langchain_helpers.agents import IndexToolConfig
-
-from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
-from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
-
-
-class DatasetToolBuilder:
-    @classmethod
-    def build_dataset_tool(cls, dataset: Dataset,
-                           response_mode: str = "no_synthesizer",
-                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
-        if dataset.indexing_technique == "economy":
-            # use keyword table query
-            index = KeywordTableIndex(dataset=dataset).query_index
-
-            if not index:
-                return None
-
-            query_kwargs = {
-                "mode": "default",
-                "response_mode": response_mode,
-                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
-                "max_keywords_per_query": 5,
-                # If num_chunks_per_query is too large,
-                # it will slow down the synthesis process due to multiple iterations of refinement.
-                "num_chunks_per_query": 2
-            }
-        else:
-            index = VectorIndex(dataset=dataset).query_index
-
-            if not index:
-                return None
-
-            query_kwargs = {
-                "mode": "default",
-                "response_mode": response_mode,
-                # If top_k is too large,
-                # it will slow down the synthesis process due to multiple iterations of refinement.
-                "similarity_top_k": 2
-            }
-
-        # fulfill description when it is empty
-        description = dataset.description
-        if not description:
-            description = 'useful for when you want to answer queries about the ' + dataset.name
-
-        index_tool_config = IndexToolConfig(
-            index=index,
-            name=f"dataset-{dataset.id}",
-            description=description,
-            index_query_kwargs=query_kwargs,
-            tool_kwargs={
-                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
-            },
-            # tool_kwargs={"return_direct": True},
-            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
-        )
-
-        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
-
-        return EnhanceLlamaIndexTool.from_tool_config(
-            tool_config=index_tool_config,
-            callback_handler=index_callback_handler
-        )

+ 0 - 43
api/core/tool/llama_index_tool.py

@@ -1,43 +0,0 @@
-from typing import Dict
-
-from langchain.tools import BaseTool
-from llama_index.indices.base import BaseGPTIndex
-from llama_index.langchain_helpers.agents import IndexToolConfig
-from pydantic import Field
-
-from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
-
-
-class EnhanceLlamaIndexTool(BaseTool):
-    """Tool for querying a LlamaIndex."""
-
-    # NOTE: name/description still needs to be set
-    index: BaseGPTIndex
-    query_kwargs: Dict = Field(default_factory=dict)
-    return_sources: bool = False
-    callback_handler: IndexToolCallbackHandler
-
-    @classmethod
-    def from_tool_config(cls, tool_config: IndexToolConfig,
-                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
-        """Create a tool from a tool config."""
-        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
-        return cls(
-            index=tool_config.index,
-            callback_handler=callback_handler,
-            name=tool_config.name,
-            description=tool_config.description,
-            return_sources=return_sources,
-            query_kwargs=tool_config.index_query_kwargs,
-            **tool_config.tool_kwargs,
-        )
-
-    def _run(self, tool_input: str) -> str:
-        response = self.index.query(tool_input, **self.query_kwargs)
-        self.callback_handler.on_tool_end(response)
-        return str(response)
-
-    async def _arun(self, tool_input: str) -> str:
-        response = await self.index.aquery(tool_input, **self.query_kwargs)
-        self.callback_handler.on_tool_end(response)
-        return str(response)

+ 0 - 34
api/core/vector_store/base.py

@@ -1,34 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional
-
-from llama_index import ServiceContext, GPTVectorStoreIndex
-from llama_index.data_structs import Node
-from llama_index.vector_stores.types import VectorStore
-
-
-class BaseVectorStoreClient(ABC):
-    @abstractmethod
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        raise NotImplementedError
-
-    @abstractmethod
-    def to_index_config(self, index_id: str) -> dict:
-        raise NotImplementedError
-
-
-class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
-    def delete_node(self, node_id: str):
-        self._vector_store.delete_node(node_id)
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        return self._vector_store.exists_by_node_id(node_id)
-
-
-class EnhanceVectorStore(ABC):
-    @abstractmethod
-    def delete_node(self, node_id: str):
-        pass
-
-    @abstractmethod
-    def exists_by_node_id(self, node_id: str) -> bool:
-        pass

+ 69 - 0
api/core/vector_store/qdrant_vector_store.py

@@ -0,0 +1,69 @@
+from typing import cast, Any
+
+from langchain.schema import Document
+from langchain.vectorstores import Qdrant
+from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
+from qdrant_client.local.qdrant_local import QdrantLocal
+
+
+class QdrantVectorStore(Qdrant):
+    def del_texts(self, filter: Filter):
+        if not filter:
+            raise ValueError('filter must not be empty')
+
+        self._reload_if_needed()
+
+        self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._reload_if_needed()
+
+        self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=PointIdsList(
+                points=[uuid],
+            ),
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        self._reload_if_needed()
+
+        response = self.client.retrieve(
+            collection_name=self.collection_name,
+            ids=[uuid]
+        )
+
+        return len(response) > 0
+
+    def delete(self):
+        self._reload_if_needed()
+
+        self.client.delete_collection(collection_name=self.collection_name)
+
+    @classmethod
+    def _document_from_scored_point(
+            cls,
+            scored_point: Any,
+            content_payload_key: str,
+            metadata_payload_key: str,
+    ) -> Document:
+        if scored_point.payload.get('doc_id'):
+            return Document(
+                page_content=scored_point.payload.get(content_payload_key),
+                metadata={'doc_id': scored_point.id}
+            )
+
+        return Document(
+            page_content=scored_point.payload.get(content_payload_key),
+            metadata=scored_point.payload.get(metadata_payload_key) or {},
+        )
+
+    def _reload_if_needed(self):
+        if isinstance(self.client, QdrantLocal):
+            self.client = cast(QdrantLocal, self.client)
+            self.client._load()

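For reference, a sketch of removing every point that belongs to one dataset document via del_texts. The payload key layout and collection setup are assumptions; the wrapper is constructed the same way as langchain's Qdrant class:

    import qdrant_client
    from langchain.embeddings import OpenAIEmbeddings
    from qdrant_client.http import models as rest

    from core.vector_store.qdrant_vector_store import QdrantVectorStore

    client = qdrant_client.QdrantClient(path="/tmp/qdrant-local")  # local mode, for illustration

    vector_store = QdrantVectorStore(
        client=client,
        collection_name="Vector_index_example",  # illustrative collection name
        embeddings=OpenAIEmbeddings(openai_api_key="sk-placeholder"),
    )

    # Delete all chunks whose payload marks them as part of one document (assumed key).
    vector_store.del_texts(rest.Filter(
        must=[rest.FieldCondition(
            key="metadata.document_id",
            match=rest.MatchValue(value="document-uuid"),
        )]
    ))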
+ 0 - 147
api/core/vector_store/qdrant_vector_store_client.py

@@ -1,147 +0,0 @@
-import os
-from typing import cast, List
-
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
-from qdrant_client.http.models import Payload, Filter
-
-import qdrant_client
-from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
-from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
-from llama_index.vector_stores import QdrantVectorStore
-from qdrant_client.local.qdrant_local import QdrantLocal
-
-from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
-
-
-class QdrantVectorStoreClient(BaseVectorStoreClient):
-
-    def __init__(self, url: str, api_key: str, root_path: str):
-        self._client = self.init_from_config(url, api_key, root_path)
-
-    @classmethod
-    def init_from_config(cls, url: str, api_key: str, root_path: str):
-        if url and url.startswith('path:'):
-            path = url.replace('path:', '')
-            if not os.path.isabs(path):
-                path = os.path.join(root_path, path)
-
-            return qdrant_client.QdrantClient(
-                path=path
-            )
-        else:
-            return qdrant_client.QdrantClient(
-                url=url,
-                api_key=api_key,
-            )
-
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        index_struct = QdrantIndexDict()
-
-        if self._client is None:
-            raise Exception("Vector client is not initialized.")
-
-        # {"collection_name": "Gpt_index_xxx"}
-        collection_name = config.get('collection_name')
-        if not collection_name:
-            raise Exception("collection_name cannot be None.")
-
-        return GPTQdrantEnhanceIndex(
-            service_context=service_context,
-            index_struct=index_struct,
-            vector_store=QdrantEnhanceVectorStore(
-                client=self._client,
-                collection_name=collection_name
-            )
-        )
-
-    def to_index_config(self, index_id: str) -> dict:
-        return {"collection_name": index_id}
-
-
-class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
-    pass
-
-
-class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
-    def delete_node(self, node_id: str):
-        """
-        Delete node from the index.
-
-        :param node_id: node id
-        """
-        from qdrant_client.http import models as rest
-
-        self._reload_if_needed()
-
-        self._client.delete(
-            collection_name=self._collection_name,
-            points_selector=rest.Filter(
-                must=[
-                    rest.FieldCondition(
-                        key="id", match=rest.MatchValue(value=node_id)
-                    )
-                ]
-            ),
-        )
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        """
-        Get node from the index by node id.
-
-        :param node_id: node id
-        """
-        self._reload_if_needed()
-
-        response = self._client.retrieve(
-            collection_name=self._collection_name,
-            ids=[node_id]
-        )
-
-        return len(response) > 0
-
-    def query(
-        self,
-        query: VectorStoreQuery,
-    ) -> VectorStoreQueryResult:
-        """Query index for top k most similar nodes.
-
-        Args:
-            query (VectorStoreQuery): query
-        """
-        query_embedding = cast(List[float], query.query_embedding)
-
-        self._reload_if_needed()
-
-        response = self._client.search(
-            collection_name=self._collection_name,
-            query_vector=query_embedding,
-            limit=cast(int, query.similarity_top_k),
-            query_filter=cast(Filter, self._build_query_filter(query)),
-            with_vectors=True
-        )
-
-        nodes = []
-        similarities = []
-        ids = []
-        for point in response:
-            payload = cast(Payload, point.payload)
-            node = Node(
-                doc_id=str(point.id),
-                text=payload.get("text"),
-                embedding=point.vector,
-                extra_info=payload.get("extra_info"),
-                relationships={
-                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
-                },
-            )
-            nodes.append(node)
-            similarities.append(point.score)
-            ids.append(str(point.id))
-
-        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
-
-    def _reload_if_needed(self):
-        if isinstance(self._client._client, QdrantLocal):
-            self._client._client._load()

+ 0 - 62
api/core/vector_store/vector_store.py

@@ -1,62 +0,0 @@
-from flask import Flask
-from llama_index import ServiceContext, GPTVectorStoreIndex
-from requests import ReadTimeout
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
-
-from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
-from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
-
-SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
-
-
-class VectorStore:
-
-    def __init__(self):
-        self._vector_store = None
-        self._client = None
-
-    def init_app(self, app: Flask):
-        if not app.config['VECTOR_STORE']:
-            return
-
-        self._vector_store = app.config['VECTOR_STORE']
-        if self._vector_store not in SUPPORTED_VECTOR_STORES:
-            raise ValueError(f"Vector store {self._vector_store} is not supported.")
-
-        if self._vector_store == 'weaviate':
-            self._client = WeaviateVectorStoreClient(
-                endpoint=app.config['WEAVIATE_ENDPOINT'],
-                api_key=app.config['WEAVIATE_API_KEY'],
-                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
-                batch_size=app.config['WEAVIATE_BATCH_SIZE']
-            )
-        elif self._vector_store == 'qdrant':
-            self._client = QdrantVectorStoreClient(
-                url=app.config['QDRANT_URL'],
-                api_key=app.config['QDRANT_API_KEY'],
-                root_path=app.root_path
-            )
-
-        app.extensions['vector_store'] = self
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
-        vector_store_config: dict = index_struct.get('vector_store')
-        index = self.get_client().get_index(
-            service_context=service_context,
-            config=vector_store_config
-        )
-
-        return index
-
-    def to_index_struct(self, index_id: str) -> dict:
-        return {
-            "type": self._vector_store,
-            "vector_store": self.get_client().to_index_config(index_id)
-        }
-
-    def get_client(self):
-        if not self._client:
-            raise Exception("Vector store client is not initialized.")
-
-        return self._client

+ 0 - 66
api/core/vector_store/vector_store_index_query.py

@@ -1,66 +0,0 @@
-from llama_index.indices.query.base import IS
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional
-)
-
-from llama_index.docstore import BaseDocumentStore
-from llama_index.indices.postprocessor.node import (
-    BaseNodePostprocessor,
-)
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
-from llama_index.indices.response.response_builder import ResponseMode
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-
-from core.index.query.synthesizer import EnhanceResponseSynthesizer
-
-
-class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
-    @classmethod
-    def from_args(
-            cls,
-            index_struct: IS,
-            service_context: ServiceContext,
-            docstore: Optional[BaseDocumentStore] = None,
-            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
-            verbose: bool = False,
-            # response synthesizer args
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_kwargs: Optional[Dict] = None,
-            use_async: bool = False,
-            streaming: bool = False,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-            # class-specific args
-            **kwargs: Any,
-    ) -> "BaseGPTIndexQuery":
-        response_synthesizer = EnhanceResponseSynthesizer.from_args(
-            service_context=service_context,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            simple_template=simple_template,
-            response_mode=response_mode,
-            response_kwargs=response_kwargs,
-            use_async=use_async,
-            streaming=streaming,
-            optimizer=optimizer,
-        )
-        return cls(
-            index_struct=index_struct,
-            service_context=service_context,
-            response_synthesizer=response_synthesizer,
-            docstore=docstore,
-            node_postprocessors=node_postprocessors,
-            verbose=verbose,
-            **kwargs,
-        )

+ 38 - 0
api/core/vector_store/weaviate_vector_store.py

@@ -0,0 +1,38 @@
+from langchain.vectorstores import Weaviate
+
+
+class WeaviateVectorStore(Weaviate):
+    def del_texts(self, where_filter: dict):
+        if not where_filter:
+            raise ValueError('where_filter must not be empty')
+
+        self._client.batch.delete_objects(
+            class_name=self._index_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._client.data_object.delete(
+            uuid,
+            class_name=self._index_name
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": uuid,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][self._index_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete(self):
+        self._client.schema.delete_class(self._index_name)

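And the Weaviate counterpart, where del_texts takes a plain where-filter dict. Class name, text key and metadata property are assumptions; construction mirrors langchain's Weaviate wrapper:

    import weaviate
    from langchain.embeddings import OpenAIEmbeddings

    from core.vector_store.weaviate_vector_store import WeaviateVectorStore

    client = weaviate.Client(url="http://localhost:8080")  # illustrative endpoint

    vector_store = WeaviateVectorStore(
        client=client,
        index_name="Vector_index_example",
        text_key="text",
        embedding=OpenAIEmbeddings(openai_api_key="sk-placeholder"),
    )

    # Delete every object tagged with one document id (assumed metadata property).
    vector_store.del_texts({
        "path": ["document_id"],
        "operator": "Equal",
        "valueText": "document-uuid",
    })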
+ 0 - 270
api/core/vector_store/weaviate_vector_store_client.py

@@ -1,270 +0,0 @@
-import json
-import weaviate
-from dataclasses import field
-from typing import List, Any, Dict, Optional
-
-from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
-from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
-from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
-from llama_index.vector_stores import WeaviateVectorStore
-from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
-from llama_index.readers.weaviate.utils import (
-    parse_get_response,
-    validate_client,
-)
-
-
-class WeaviateVectorStoreClient(BaseVectorStoreClient):
-
-    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
-        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
-
-    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
-        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
-
-        weaviate.connect.connection.has_grpc = grpc_enabled
-
-        client = weaviate.Client(
-            url=endpoint,
-            auth_client_secret=auth_config,
-            timeout_config=(5, 60),
-            startup_period=None
-        )
-
-        client.batch.configure(
-            # `batch_size` takes an `int` value to enable auto-batching
-            # (`None` is used for manual batching)
-            batch_size=batch_size,
-            # dynamically update the `batch_size` based on import speed
-            dynamic=True,
-            # `timeout_retries` takes an `int` value to retry on time outs
-            timeout_retries=3,
-        )
-
-        return client
-
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        index_struct = WeaviateIndexDict()
-
-        if self._client is None:
-            raise Exception("Vector client is not initialized.")
-
-        # {"class_prefix": "Gpt_index_xxx"}
-        class_prefix = config.get('class_prefix')
-        if not class_prefix:
-            raise Exception("class_prefix cannot be None.")
-
-        return GPTWeaviateEnhanceIndex(
-            service_context=service_context,
-            index_struct=index_struct,
-            vector_store=WeaviateWithSimilaritiesVectorStore(
-                weaviate_client=self._client,
-                class_prefix=class_prefix
-            )
-        )
-
-    def to_index_config(self, index_id: str) -> dict:
-        return {"class_prefix": index_id}
-
-
-class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
-    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
-        """Query index for top k most similar nodes."""
-        nodes = self.weaviate_query(
-            self._client,
-            self._class_prefix,
-            query,
-        )
-        nodes = nodes[: query.similarity_top_k]
-        node_idxs = [str(i) for i in range(len(nodes))]
-
-        similarities = []
-        for node in nodes:
-            similarities.append(node.extra_info['similarity'])
-            del node.extra_info['similarity']
-
-        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
-
-    def weaviate_query(
-            self,
-            client: Any,
-            class_prefix: str,
-            query_spec: VectorStoreQuery,
-    ) -> List[Node]:
-        """Convert to LlamaIndex list."""
-        validate_client(client)
-
-        class_name = _class_name(class_prefix)
-        prop_names = [p["name"] for p in NODE_SCHEMA]
-        vector = query_spec.query_embedding
-
-        # build query
-        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
-        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
-            _logger.debug("Using vector search")
-            if vector is not None:
-                query = query.with_near_vector(
-                    {
-                        "vector": vector,
-                    }
-                )
-        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
-            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
-            query = query.with_hybrid(
-                query=query_spec.query_str,
-                alpha=query_spec.alpha,
-                vector=vector,
-            )
-        query = query.with_limit(query_spec.similarity_top_k)
-        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")
-
-        # execute query
-        query_result = query.do()
-
-        # parse results
-        parsed_result = parse_get_response(query_result)
-        entries = parsed_result[class_name]
-        results = [self._to_node(entry) for entry in entries]
-        return results
-
-    def _to_node(self, entry: Dict) -> Node:
-        """Convert to Node."""
-        extra_info_str = entry["extra_info"]
-        if extra_info_str == "":
-            extra_info = None
-        else:
-            extra_info = json.loads(extra_info_str)
-
-        if 'certainty' in entry['_additional']:
-            if extra_info:
-                extra_info['similarity'] = entry['_additional']['certainty']
-            else:
-                extra_info = {'similarity': entry['_additional']['certainty']}
-
-        node_info_str = entry["node_info"]
-        if node_info_str == "":
-            node_info = None
-        else:
-            node_info = json.loads(node_info_str)
-
-        relationships_str = entry["relationships"]
-        relationships: Dict[DocumentRelationship, str]
-        if relationships_str == "":
-            relationships = field(default_factory=dict)
-        else:
-            relationships = {
-                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
-            }
-
-        return Node(
-            text=entry["text"],
-            doc_id=entry["doc_id"],
-            embedding=entry["_additional"]["vector"],
-            extra_info=extra_info,
-            node_info=node_info,
-            relationships=relationships,
-        )
-
-    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
-        """Delete a document.
-
-        Args:
-            doc_id (str): document id
-
-        """
-        delete_document(self._client, doc_id, self._class_prefix)
-
-    def delete_node(self, node_id: str):
-        """
-        Delete node from the index.
-
-        :param node_id: node id
-        """
-        delete_node(self._client, node_id, self._class_prefix)
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        """
-        Get node from the index by node id.
-
-        :param node_id: node id
-        """
-        entry = get_by_node_id(self._client, node_id, self._class_prefix)
-        return True if entry else False
-
-
-class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
-    pass
-
-
-def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["ref_doc_id"],
-        "operator": "Equal",
-        "valueString": ref_doc_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    for entry in entries:
-        client.data_object.delete(entry["_additional"]["id"], class_name)
-
-    while len(entries) > 0:
-        query_result = query.do()
-        parsed_result = parse_get_response(query_result)
-        entries = parsed_result[class_name]
-        for entry in entries:
-            client.data_object.delete(entry["_additional"]["id"], class_name)
-
-
-def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["doc_id"],
-        "operator": "Equal",
-        "valueString": node_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    for entry in entries:
-        client.data_object.delete(entry["_additional"]["id"], class_name)
-
-
-def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["doc_id"],
-        "operator": "Equal",
-        "valueString": node_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    if len(entries) == 0:
-        return None
-
-    return entries[0]

+ 0 - 7
api/extensions/ext_vector_store.py

@@ -1,7 +0,0 @@
-from core.vector_store.vector_store import VectorStore
-
-vector_store = VectorStore()
-
-
-def init_app(app):
-    vector_store.init_app(app)

+ 6 - 0
api/libs/helper.py

@@ -3,6 +3,7 @@ import re
 import subprocess
 import subprocess
 import uuid
 import uuid
 from datetime import datetime
 from datetime import datetime
+from hashlib import sha256
 from zoneinfo import available_timezones
 from zoneinfo import available_timezones
 import random
 import random
 import string
 import string
@@ -147,3 +148,8 @@ def get_remote_ip(request):
         return request.headers.getlist("X-Forwarded-For")[0]
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
     else:
         return request.remote_addr
         return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()
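
For reference, the new helper can be exercised on its own. A minimal standalone sketch mirroring the function added above; how callers use it (e.g. hashing segment text for duplicate checks) is an assumption here:

    from hashlib import sha256

    def generate_text_hash(text: str) -> str:
        # same logic as the helper added to api/libs/helper.py above
        hash_text = str(text) + 'None'
        return sha256(hash_text.encode()).hexdigest()

    if __name__ == '__main__':
        a = generate_text_hash('hello world')
        b = generate_text_hash('hello world')
        assert a == b    # deterministic: identical text yields an identical 64-char hex digest
        print(a)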

+ 0 - 2
api/models/account.py

@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
-    _current_tenant: db.Model = None
-
     @property
     @property
     def current_tenant(self):
     def current_tenant(self):
         return self._current_tenant
         return self._current_tenant

+ 30 - 2
api/models/dataset.py

@@ -66,6 +66,23 @@ class Dataset(db.Model):
     def document_count(self):
     def document_count(self):
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
 
 
+    @property
+    def available_document_count(self):
+        return db.session.query(func.count(Document.id)).filter(
+            Document.dataset_id == self.id,
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False
+        ).scalar()
+
+    @property
+    def available_segment_count(self):
+        return db.session.query(func.count(DocumentSegment.id)).filter(
+            DocumentSegment.dataset_id == self.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).scalar()
+
     @property
     @property
     def word_count(self):
     def word_count(self):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
@@ -260,7 +277,7 @@ class Document(db.Model):
 
 
     @property
     @property
     def dataset(self):
     def dataset(self):
-        return Dataset.query.get(self.dataset_id)
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
 
 
     @property
     @property
     def segment_count(self):
     def segment_count(self):
@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):
 
 
     @property
     @property
     def keyword_table_dict(self):
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if isinstance(dct, dict):
+                    for keyword, node_idxs in dct.items():
+                        if isinstance(node_idxs, list):
+                            dct[keyword] = set(node_idxs)
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
 
 
 
 
 class Embedding(db.Model):
 class Embedding(db.Model):
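
To illustrate the round trip behind keyword_table_dict: the keyword table maps each keyword to a set of index node ids, and since JSON has no set type the sets are presumably written out as lists and restored to sets by the decoder shown above. A standalone sketch under that assumption (the set-to-list step on write is illustrative, not taken from the diff):

    import json

    # adapted from the SetDecoder added above
    class SetDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            super().__init__(object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, dct):
            # turn every list value back into a set of node ids
            for keyword, node_idxs in dct.items():
                if isinstance(node_idxs, list):
                    dct[keyword] = set(node_idxs)
            return dct

    keyword_table = {'langchain': {'node-1', 'node-2'}, 'weaviate': {'node-3'}}
    serialized = json.dumps(keyword_table, default=list)   # sets become lists when stored
    restored = json.loads(serialized, cls=SetDecoder)      # lists become sets again on read
    assert restored['langchain'] == {'node-1', 'node-2'}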

+ 5 - 4
api/requirements.txt

@@ -2,6 +2,7 @@ coverage~=7.2.4
 beautifulsoup4==4.12.2
 beautifulsoup4==4.12.2
 flask~=2.3.2
 flask~=2.3.2
 Flask-SQLAlchemy~=3.0.3
 Flask-SQLAlchemy~=3.0.3
+SQLAlchemy~=1.4.28
 flask-login==0.6.2
 flask-login==0.6.2
 flask-migrate~=4.0.4
 flask-migrate~=4.0.4
 flask-restful==0.3.9
 flask-restful==0.3.9
@@ -9,8 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gunicorn~=20.1.0
 gevent~=22.10.2
 gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.209
 openai~=0.27.5
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
 pycryptodome==3.17
@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
 jieba==0.42.1
 jieba==0.42.1
 celery==5.2.7
 celery==5.2.7
 redis~=4.5.4
 redis~=4.5.4
-pypdf==3.8.1
 openpyxl==3.1.2
 openpyxl==3.1.2
-chardet~=5.1.0
+chardet~=5.1.0
+docx2txt==0.8
+pypdfium2==4.16.0

+ 0 - 1
api/services/app_model_config_service.py

@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from core.constant import llm_constant
 from models.account import Account
 from models.account import Account
 from services.dataset_service import DatasetService
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError
 
 
 
 
 class AppModelConfigService:
 class AppModelConfigService:

+ 0 - 3
api/services/dataset_service.py

@@ -7,7 +7,6 @@ from typing import Optional, List
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from flask_login import current_user
 from flask_login import current_user
 
 
-from core.index.index_builder import IndexBuilder
 from events.dataset_event import dataset_was_deleted
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from extensions.ext_database import db
@@ -386,8 +385,6 @@ class DocumentService:
 
 
             dataset.indexing_technique = document_data["indexing_technique"]
             dataset.indexing_technique = document_data["indexing_technique"]
 
 
-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
         documents = []
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
         if 'original_document_id' in document_data and document_data["original_document_id"]:

+ 44 - 36
api/services/hit_testing_service.py

@@ -3,47 +3,56 @@ import time
 from typing import List
 from typing import List
 
 
 import numpy as np
 import numpy as np
-from llama_index.data_structs.node_v2 import NodeWithScore
-from llama_index.indices.query.schema import QueryBundle
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from sklearn.manifold import TSNE
 from sklearn.manifold import TSNE
 
 
-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.account import Account
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-from services.errors.index import IndexNotInitializedError
 
 
 
 
 class HitTestingService:
 class HitTestingService:
     @classmethod
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        index = VectorIndex(dataset=dataset).query_index
-
-        if not index:
-            raise IndexNotInitializedError()
-
-        index_query = GPTVectorStoreIndexQuery(
-            index_struct=index.index_struct,
-            service_context=index.service_context,
-            vector_store=index.query_context.get('vector_store'),
-            docstore=EmptyDocumentStore(),
-            response_synthesizer=None,
-            similarity_top_k=limit
-        )
+        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+            return {
+                "query": {
+                    "content": query,
+                    "tsne_position": {'x': 0, 'y': 0},
+                },
+                "records": []
+            }
 
 
-        query_bundle = QueryBundle(
-            query_str=query,
-            custom_embedding_strs=[query],
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
         )
         )
 
 
-        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
-            query_bundle.embedding_strs
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
         )
         )
 
 
         start = time.perf_counter()
         start = time.perf_counter()
-        nodes = index_query.retrieve(query_bundle=query_bundle)
+        documents = vector_index.search(
+            query,
+            search_type='similarity_score_threshold',
+            search_kwargs={
+                'k': 10
+            }
+        )
         end = time.perf_counter()
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
 
 
@@ -58,25 +67,24 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.add(dataset_query)
         db.session.commit()
         db.session.commit()
 
 
-        return cls.compact_retrieve_response(dataset, query_bundle, nodes)
+        return cls.compact_retrieve_response(dataset, embeddings, query, documents)
 
 
     @classmethod
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
-        embeddings = [
-            query_bundle.embedding
+    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
+        text_embeddings = [
+            embeddings.embed_query(query)
         ]
         ]
 
 
-        for node in nodes:
-            embeddings.append(node.node.embedding)
+        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
 
 
-        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
+        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
 
 
         query_position = tsne_position_data.pop(0)
         query_position = tsne_position_data.pop(0)
 
 
         i = 0
         i = 0
         records = []
         records = []
-        for node in nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            index_node_id = document.metadata['doc_id']
 
 
             segment = db.session.query(DocumentSegment).filter(
             segment = db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == dataset.id,
                 DocumentSegment.dataset_id == dataset.id,
@@ -91,7 +99,7 @@ class HitTestingService:
 
 
             record = {
             record = {
                 "segment": segment,
                 "segment": segment,
-                "score": node.score,
+                "score": document.metadata['score'],
                 "tsne_position": tsne_position_data[i]
                 "tsne_position": tsne_position_data[i]
             }
             }
 
 
@@ -101,7 +109,7 @@ class HitTestingService:
 
 
         return {
         return {
             "query": {
             "query": {
-                "content": query_bundle.query_str,
+                "content": query,
                 "tsne_position": query_position,
                 "tsne_position": query_position,
             },
             },
             "records": records
             "records": records

+ 33 - 48
api/tasks/add_document_to_index_task.py

@@ -4,96 +4,81 @@ import time
 
 
 import click
 import click
 from celery import shared_task
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument
 
 
 
 
 @shared_task
 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     """
     Async Add document to index
     Async Add document to index
     :param document_id:
     :param document_id:
 
 
     Usage: add_document_to_index.delay(document_id)
     Usage: add_document_to_index.delay(document_id)
     """
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()
     start_at = time.perf_counter()
 
 
-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')
         raise NotFound('Document not found')
 
 
-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return
         return
 
 
-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
 
 
     try:
     try:
         segments = db.session.query(DocumentSegment).filter(
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
             DocumentSegment.enabled == True
         ) \
         ) \
             .order_by(DocumentSegment.position.asc()).all()
             .order_by(DocumentSegment.position.asc()).all()
 
 
-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )
             )
 
 
-            previous_node = node
+            documents.append(document)
 
 
-            nodes.append(node)
-
-        dataset = document.dataset
+        dataset = dataset_document.dataset
 
 
         if not dataset:
         if not dataset:
             raise Exception('Document has no dataset')
             raise Exception('Document has no dataset')
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)
 
 
         # save keyword index
         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)
 
 
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
     except Exception as e:
         logging.exception("add document to index failed")
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
         db.session.commit()
     finally:
     finally:
         redis_client.delete(indexing_cache_key)
         redis_client.delete(indexing_cache_key)
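
The same segment-to-Document conversion recurs in the indexing tasks below. A compact sketch of the shared pattern, assuming `segments` is a list of DocumentSegment rows:

    from langchain.schema import Document

    from core.index.index import IndexBuilder

    def index_segments(dataset, segments):
        documents = [
            Document(
                page_content=segment.content,
                metadata={
                    'doc_id': segment.index_node_id,
                    'doc_hash': segment.index_node_hash,
                    'document_id': segment.document_id,
                    'dataset_id': segment.dataset_id,
                }
            )
            for segment in segments
        ]

        # get_index may return None, hence the `if index:` guards in the tasks above
        for index_type in ('high_quality', 'economy'):
            index = IndexBuilder.get_index(dataset, index_type)
            if index:
                index.add_texts(documents)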

+ 27 - 32
api/tasks/add_segment_to_index_task.py

@@ -4,12 +4,10 @@ import time
 
 
 import click
 import click
 from celery import shared_task
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
 from models.dataset import DocumentSegment
@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
 
 
     try:
     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )
         )
 
 
         dataset = segment.dataset
         dataset = segment.dataset
 
 
         if not dataset:
         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
 
 
         # save vector index
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)
 
 
         # save keyword index
         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])
 
 
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))

+ 13 - 20
api/tasks/clean_dataset_task.py

@@ -4,8 +4,7 @@ import time
 import click
 import click
 from celery import shared_task
 from celery import shared_task
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
     AppDatasetJoin
@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             index_struct=index_struct
             index_struct=index_struct
         )
         )
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
         documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-        index_doc_ids = [document.id for document in documents]
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-        index_node_ids = [segment.index_node_id for segment in segments]
 
 
-        # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            for index_doc_id in index_doc_ids:
-                try:
-                    vector_index.del_doc(index_doc_id)
-                except Exception:
-                    logging.exception("Delete doc index failed when dataset deleted.")
-                    continue
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
-        # delete from keyword index
-        if index_node_ids:
+        # delete from vector index
+        if vector_index:
             try:
             try:
-                keyword_table_index.del_nodes(index_node_ids)
+                vector_index.delete()
             except Exception:
             except Exception:
-                logging.exception("Delete nodes index failed when dataset deleted.")
+                logging.exception("Delete doc index failed when dataset deleted.")
+
+        # delete from keyword index
+        try:
+            kw_index.delete()
+        except Exception:
+            logging.exception("Delete nodes index failed when dataset deleted.")
 
 
         for document in documents:
         for document in documents:
             db.session.delete(document)
             db.session.delete(document)
@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         for segment in segments:
         for segment in segments:
             db.session.delete(segment)
             db.session.delete(segment)
 
 
-        db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
         db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
         db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
         db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
         db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
         db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
         db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()

+ 7 - 6
api/tasks/clean_document_task.py

@@ -4,8 +4,7 @@ import time
 import click
 import click
 from celery import shared_task
 from celery import shared_task
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
 from models.dataset import DocumentSegment, Dataset
 
 
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
         if not dataset:
             raise Exception('Document has no dataset')
             raise Exception('Document has no dataset')
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         index_node_ids = [segment.index_node_id for segment in segments]
 
 
         # delete from vector index
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)
 
 
         # delete from keyword index
         # delete from keyword index
         if index_node_ids:
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
 
         for segment in segments:
         for segment in segments:
             db.session.delete(segment)
             db.session.delete(segment)
+
         db.session.commit()
         db.session.commit()
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(
         logging.info(
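
For symmetry, the deletion side shared by the cleanup tasks, sketched with assumed inputs (`dataset`, `document_id`, `index_node_ids`):

    from core.index.index import IndexBuilder

    def remove_document_from_indexes(dataset, document_id, index_node_ids):
        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
        kw_index = IndexBuilder.get_index(dataset, 'economy')

        # the vector index deletes by source document, the keyword index by node ids
        if vector_index:
            vector_index.delete_by_document_id(document_id)

        if index_node_ids:
            kw_index.delete_by_ids(index_node_ids)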

+ 7 - 6
api/tasks/clean_notion_document_task.py

@@ -5,8 +5,7 @@ from typing import List
 import click
 import click
 from celery import shared_task
 from celery import shared_task
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
 from models.dataset import DocumentSegment, Dataset, Document
 
 
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
         if not dataset:
             raise Exception('Document has no dataset')
             raise Exception('Document has no dataset')
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         for document_id in document_ids:
         for document_id in document_ids:
             document = db.session.query(Document).filter(
             document = db.session.query(Document).filter(
                 Document.id == document_id
                 Document.id == document_id
             ).first()
             ).first()
             db.session.delete(document)
             db.session.delete(document)
+
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
             index_node_ids = [segment.index_node_id for segment in segments]
 
 
             # delete from vector index
             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)
 
 
             # delete from keyword index
             # delete from keyword index
             if index_node_ids:
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)
 
 
             for segment in segments:
             for segment in segments:
                 db.session.delete(segment)
                 db.session.delete(segment)

+ 36 - 36
api/tasks/deal_dataset_vector_index_task.py

@@ -3,10 +3,12 @@ import time
 
 
 import click
 import click
 from celery import shared_task
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
-from core.index.vector_index import VectorIndex
+from langchain.schema import Document
+
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument
 
 
 
 
 @shared_task
 @shared_task
@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         dataset = Dataset.query.filter_by(
         dataset = Dataset.query.filter_by(
             id=dataset_id
             id=dataset_id
         ).first()
         ).first()
+
         if not dataset:
         if not dataset:
             raise Exception('Dataset not found')
             raise Exception('Dataset not found')
-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            vector_index = VectorIndex(dataset=dataset)
-            for document in documents:
-                # delete from vector index
-                if action == "remove":
-                    vector_index.del_doc(document.id)
-                elif action == "add":
+
+        if action == "remove":
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index.delete()
+        elif action == "add":
+            dataset_documents = db.session.query(DatasetDocument).filter(
+                DatasetDocument.dataset_id == dataset_id,
+                DatasetDocument.indexing_status == 'completed',
+                DatasetDocument.enabled == True,
+                DatasetDocument.archived == False,
+            ).all()
+
+            if dataset_documents:
+                # save vector index
+                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+                for dataset_document in dataset_documents:
+                    # delete from vector index
                     segments = db.session.query(DocumentSegment).filter(
                     segments = db.session.query(DocumentSegment).filter(
-                        DocumentSegment.document_id == document.id,
+                        DocumentSegment.document_id == dataset_document.id,
                         DocumentSegment.enabled == True
                         DocumentSegment.enabled == True
                     ) .order_by(DocumentSegment.position.asc()).all()
                     ) .order_by(DocumentSegment.position.asc()).all()
 
 
-                    nodes = []
-                    previous_node = None
+                    documents = []
                     for segment in segments:
                     for segment in segments:
-                        relationships = {
-                            DocumentRelationship.SOURCE: document.id
-                        }
-
-                        if previous_node:
-                            relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                            previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                        node = Node(
-                            doc_id=segment.index_node_id,
-                            doc_hash=segment.index_node_hash,
-                            text=segment.content,
-                            extra_info=None,
-                            node_info=None,
-                            relationships=relationships
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
                         )
                         )
 
 
-                        previous_node = node
-                        nodes.append(node)
+                        documents.append(document)
+
                     # save vector index
                     # save vector index
-                    vector_index.add_nodes(
-                        nodes=nodes,
-                        duplicate_check=True
-                    )
+                    index.add_texts(documents)
 
 
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(
         logging.info(

+ 23 - 23
api/tasks/document_indexing_sync_task.py

@@ -6,11 +6,9 @@ import click
 from celery import shared_task
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.data_source.notion import NotionPageReader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.data_loader.loader.notion import NotionLoader
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
 from models.source import DataSourceBinding
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             raise ValueError("no notion page found")
             raise ValueError("no notion page found")
         workspace_id = data_source_info['notion_workspace_id']
         workspace_id = data_source_info['notion_workspace_id']
         page_id = data_source_info['notion_page_id']
         page_id = data_source_info['notion_page_id']
+        page_type = data_source_info['type']
         page_edited_time = data_source_info['last_edited_time']
         page_edited_time = data_source_info['last_edited_time']
         data_source_binding = DataSourceBinding.query.filter(
         data_source_binding = DataSourceBinding.query.filter(
             db.and_(
             db.and_(
@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         ).first()
         ).first()
         if not data_source_binding:
         if not data_source_binding:
             raise ValueError('Data source binding not found.')
             raise ValueError('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        last_edited_time = loader.get_notion_last_edited_time()
+
         # check the page is updated
         # check the page is updated
         if last_edited_time != page_edited_time:
         if last_edited_time != page_edited_time:
             document.indexing_status = 'parsing'
             document.indexing_status = 'parsing'
@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 if not dataset:
                 if not dataset:
                     raise Exception('Dataset not found')
                     raise Exception('Dataset not found')
 
 
-                vector_index = VectorIndex(dataset=dataset)
-                keyword_table_index = KeywordTableIndex(dataset=dataset)
+                vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+                kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
                 segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                 segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                 index_node_ids = [segment.index_node_id for segment in segments]
                 index_node_ids = [segment.index_node_id for segment in segments]
 
 
                 # delete from vector index
                 # delete from vector index
-                vector_index.del_nodes(index_node_ids)
+                if vector_index:
+                    vector_index.delete_by_document_id(document_id)
 
 
                 # delete from keyword index
                 # delete from keyword index
                 if index_node_ids:
                 if index_node_ids:
-                    keyword_table_index.del_nodes(index_node_ids)
+                    kw_index.delete_by_ids(index_node_ids)
 
 
                 for segment in segments:
                 for segment in segments:
                     db.session.delete(segment)
                     db.session.delete(segment)
@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                     click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
                     click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
             except Exception:
             except Exception:
                 logging.exception("Cleaned document when document update data source or process rule failed")
                 logging.exception("Cleaned document when document update data source or process rule failed")
+
             try:
             try:
                 indexing_runner = IndexingRunner()
                 indexing_runner = IndexingRunner()
                 indexing_runner.run([document])
                 indexing_runner.run([document])
                 end_at = time.perf_counter()
                 end_at = time.perf_counter()
                 logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
                 logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-            except DocumentIsPausedException:
-                logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-            except ProviderTokenNotInitError as e:
-                document.indexing_status = 'error'
-                document.error = str(e.description)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
-            except Exception as e:
-                logging.exception("consume update document failed")
-                document.indexing_status = 'error'
-                document.error = str(e)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
+            except DocumentIsPausedException as ex:
+                logging.info(click.style(str(ex), fg='yellow'))
+            except Exception:
+                pass
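
The Notion freshness check now goes through NotionLoader. A sketch with placeholder values; the stored last_edited_time comes from the document's data_source_info as shown above:

    from core.data_loader.loader.notion import NotionLoader

    loader = NotionLoader(
        notion_access_token='<data_source_binding.access_token>',
        notion_workspace_id='<notion_workspace_id>',
        notion_obj_id='<notion_page_id>',
        notion_page_type='<type, e.g. page>'
    )

    # re-index only when Notion reports a newer edit than the one recorded locally
    if loader.get_notion_last_edited_time() != '<stored last_edited_time>':
        print('page changed since last sync, re-run indexing')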

+ 6 - 16
api/tasks/document_indexing_task.py

@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Document
 from models.dataset import Document
 
 
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     """
     documents = []
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()
 
 
         document = db.session.query(Document).filter(
         document = db.session.query(Document).filter(
             Document.id == document_id,
             Document.id == document_id,
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         indexing_runner.run(documents)
         end_at = time.perf_counter()
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 11 - 20
api/tasks/document_indexing_update_task.py

@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.dataset import Document, Dataset, DocumentSegment
 
 
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
         if not dataset:
             raise Exception('Dataset not found')
             raise Exception('Dataset not found')
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         index_node_ids = [segment.index_node_id for segment in segments]
 
 
         # delete from vector index
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)
 
 
         # delete from keyword index
         # delete from keyword index
         if index_node_ids:
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
 
         for segment in segments:
         for segment in segments:
             db.session.delete(segment)
             db.session.delete(segment)
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")
         logging.exception("Cleaned document when document update data source or process rule failed")
+
     try:
     try:
         indexing_runner = IndexingRunner()
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         indexing_runner.run([document])
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 4 - 9
api/tasks/recover_document_indexing_task.py

@@ -1,4 +1,3 @@
-import datetime
 import logging
 import logging
 import time
 import time
 
 
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
             indexing_runner.run_in_indexing_status(document)
             indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 5 - 6
api/tasks/remove_document_from_index_task.py

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
 from models.dataset import DocumentSegment, Document
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
         if not dataset:
             raise Exception('Document has no dataset')
             raise Exception('Document has no dataset')
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
         # delete from vector index
         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)
 
 
         # delete from keyword index
         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
 
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(
         logging.info(

+ 18 - 8
api/tasks/remove_segment_from_index_task.py

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 from werkzeug.exceptions import NotFound
 
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
 from models.dataset import DocumentSegment
@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
         dataset = segment.dataset
         dataset = segment.dataset
 
 
         if not dataset:
         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return
 
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
 
         # delete from vector index
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])
 
 
         # delete from keyword index
         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])
 
 
         end_at = time.perf_counter()
         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))