Explorar el Código

feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
John Wang hace 1 año
padre
commit
3241e4015b
Se han modificado 91 ficheros con 2689 adiciones y 3139 borrados
  1. 1 2
      api/app.py
  2. 35 0
      api/commands.py
  3. 2 0
      api/config.py
  4. 11 10
      api/controllers/console/datasets/data_source.py
  5. 2 28
      api/controllers/console/datasets/file.py
  6. 7 2
      api/controllers/console/version.py
  7. 0 20
      api/core/__init__.py
  8. 8 11
      api/core/agent/agent_builder.py
  9. 1 29
      api/core/callback_handler/agent_loop_gather_callback_handler.py
  10. 1 50
      api/core/callback_handler/dataset_tool_callback_handler.py
  11. 8 21
      api/core/callback_handler/index_tool_callback_handler.py
  12. 31 85
      api/core/callback_handler/llm_callback_handler.py
  13. 12 70
      api/core/callback_handler/main_chain_gather_callback_handler.py
  14. 36 9
      api/core/callback_handler/std_out_callback_handler.py
  15. 2 4
      api/core/chain/chain_builder.py
  16. 6 4
      api/core/chain/llm_router_chain.py
  17. 10 8
      api/core/chain/main_chain_builder.py
  18. 62 16
      api/core/chain/multi_dataset_router_chain.py
  19. 7 2
      api/core/chain/sensitive_word_avoidance_chain.py
  20. 12 3
      api/core/chain/tool_chain.py
  21. 42 17
      api/core/completion.py
  22. 4 4
      api/core/conversation_message_task.py
  23. 43 0
      api/core/data_loader/file_extractor.py
  24. 67 0
      api/core/data_loader/loader/csv.py
  25. 43 0
      api/core/data_loader/loader/excel.py
  26. 35 0
      api/core/data_loader/loader/html.py
  27. 134 0
      api/core/data_loader/loader/markdown.py
  28. 236 236
      api/core/data_loader/loader/notion.py
  29. 55 0
      api/core/data_loader/loader/pdf.py
  30. 35 45
      api/core/docstore/dataset_docstore.py
  31. 0 51
      api/core/docstore/empty_docstore.py
  32. 72 0
      api/core/embedding/cached_embedding.py
  33. 0 214
      api/core/embedding/openai_embedding.py
  34. 59 0
      api/core/index/base.py
  35. 41 0
      api/core/index/index.py
  36. 0 60
      api/core/index/index_builder.py
  37. 0 159
      api/core/index/keyword_table/jieba_keyword_table.py
  38. 0 135
      api/core/index/keyword_table_index.py
  39. 33 0
      api/core/index/keyword_table_index/jieba_keyword_table_handler.py
  40. 238 0
      api/core/index/keyword_table_index/keyword_table_index.py
  41. 0 0
      api/core/index/keyword_table_index/stopwords.py
  42. 0 79
      api/core/index/query/synthesizer.py
  43. 0 22
      api/core/index/readers/html_parser.py
  44. 0 111
      api/core/index/readers/markdown_parser.py
  45. 0 56
      api/core/index/readers/pdf_parser.py
  46. 0 33
      api/core/index/readers/xlsx_parser.py
  47. 0 136
      api/core/index/vector_index.py
  48. 175 0
      api/core/index/vector_index/base.py
  49. 116 0
      api/core/index/vector_index/qdrant_vector_index.py
  50. 69 0
      api/core/index/vector_index/vector_index.py
  51. 132 0
      api/core/index/vector_index/weaviate_vector_index.py
  52. 262 283
      api/core/indexing_runner.py
  53. 17 14
      api/core/llm/llm_builder.py
  54. 4 1
      api/core/llm/provider/azure_provider.py
  55. 13 50
      api/core/llm/streamable_azure_chat_open_ai.py
  56. 13 6
      api/core/llm/streamable_azure_open_ai.py
  57. 14 48
      api/core/llm/streamable_chat_open_ai.py
  58. 13 5
      api/core/llm/streamable_open_ai.py
  59. 1 1
      api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py
  60. 0 19
      api/core/prompt/prompts.py
  61. 0 0
      api/core/spiltter/fixed_text_splitter.py
  62. 87 0
      api/core/tool/dataset_index_tool.py
  63. 0 73
      api/core/tool/dataset_tool_builder.py
  64. 0 43
      api/core/tool/llama_index_tool.py
  65. 0 34
      api/core/vector_store/base.py
  66. 69 0
      api/core/vector_store/qdrant_vector_store.py
  67. 0 147
      api/core/vector_store/qdrant_vector_store_client.py
  68. 0 62
      api/core/vector_store/vector_store.py
  69. 0 66
      api/core/vector_store/vector_store_index_query.py
  70. 38 0
      api/core/vector_store/weaviate_vector_store.py
  71. 0 270
      api/core/vector_store/weaviate_vector_store_client.py
  72. 0 7
      api/extensions/ext_vector_store.py
  73. 6 0
      api/libs/helper.py
  74. 0 2
      api/models/account.py
  75. 30 2
      api/models/dataset.py
  76. 5 4
      api/requirements.txt
  77. 0 1
      api/services/app_model_config_service.py
  78. 0 3
      api/services/dataset_service.py
  79. 44 36
      api/services/hit_testing_service.py
  80. 33 48
      api/tasks/add_document_to_index_task.py
  81. 27 32
      api/tasks/add_segment_to_index_task.py
  82. 13 20
      api/tasks/clean_dataset_task.py
  83. 7 6
      api/tasks/clean_document_task.py
  84. 7 6
      api/tasks/clean_notion_document_task.py
  85. 36 36
      api/tasks/deal_dataset_vector_index_task.py
  86. 23 23
      api/tasks/document_indexing_sync_task.py
  87. 6 16
      api/tasks/document_indexing_task.py
  88. 11 20
      api/tasks/document_indexing_update_task.py
  89. 4 9
      api/tasks/recover_document_indexing_task.py
  90. 5 6
      api/tasks/remove_document_from_index_task.py
  91. 18 8
      api/tasks/remove_segment_from_index_task.py

+ 1 - 2
api/app.py

@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)

+ 35 - 0
api/commands.py

@@ -1,15 +1,19 @@
 import datetime
+import logging
 import random
 import string
 
 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound
 
+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64
@@ -159,8 +163,39 @@ def generate_upper_string():
     return result
 
 
+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)

+ 2 - 0
api/config.py

@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
 
 
 class CloudEditionConfig(Config):

+ 11 - 10
api/controllers/console/datasets/data_source.py

@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()
         if not data_source_binding:
             raise NotFound('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()
         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
 
     @setup_required

+ 2 - 28
api/controllers/console/datasets/file.py

@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()
 
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
-
+        text = FileExtractor.load(upload_file, return_text=True)
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
         return {'content': text}
 

+ 7 - 2
api/controllers/console/version.py

@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }
 
         content = json.loads(response.content)
         return {

+ 0 - 20
api/core/__init__.py

@@ -3,19 +3,11 @@ from typing import Optional
 
 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel
 
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
 
 
 class HostedOpenAICredential(BaseModel):
@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-
-    default_tfidf.stop_words = STOPWORDS
 
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))

+ 8 - 11
api/core/agent/agent_builder.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )
 
-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]
 
         prompt = cls.build_agent_prompt_template(
             tools=tools,
@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

+ 1 - 29
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
 
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None
 
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
     def on_tool_start(
         self,
         serialized: Dict[str, Any],
@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None
 
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer

+ 1 - 50
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask
@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass

+ 8 - 21
api/core/callback_handler/index_tool_callback_handler.py

@@ -1,39 +1,26 @@
-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document
 
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
 
 
-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id
 
-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']
 
             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False

+ 31 - 85
api/core/callback_handler/llm_callback_handler.py

@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True
 
     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):
@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Whether to call verbose callbacks even if verbose is False."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()
 
-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
 
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(
-            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_tool_start(
-            self,
-            serialized: Dict[str, Any],
-            input_str: str,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        pass
-
-    def on_agent_finish(
-            self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        pass

+ 12 - 70
api/core/callback_handler/main_chain_gather_callback_handler.py

@@ -1,10 +1,9 @@
 import logging
 import time
 
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
 
 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True
 
     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
-        self.clear_chain_results()
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
+        self.clear_chain_results()

+ 36 - 9
api/core/callback_handler/std_out_callback_handler.py

@@ -1,9 +1,10 @@
+import os
 import sys
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
 
 
 class DifyStdOutCallbackHandler(BaseCallbackHandler):
@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         self.color = color
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
 
+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
 
 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""

+ 2 - 4
api/core/chain/chain_builder.py

@@ -1,7 +1,5 @@
 from typing import Optional
 
-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )
 
     @classmethod
@@ -27,7 +25,7 @@ class ChainBuilder:
                 sensitive_words=sensitive_words.split(","),
                 canned_response=tool_config.get("canned_response", ''),
                 output_key="sensitive_word_avoidance_output",
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+                callbacks=[DifyStdOutCallbackHandler()],
                 **kwargs
             )
 

+ 6 - 4
api/core/chain/llm_router_chain.py

@@ -1,15 +1,16 @@
 """Base classes for LLM-powered router chains."""
 from __future__ import annotations
 
-import json
 from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
 
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from pydantic import root_validator
 
 from langchain.chains import LLMChain
 from langchain.prompts import BasePromptTemplate
-from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
+from langchain.schema import BaseOutputParser, OutputParserException
 
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
             raise ValueError
 
     def _call(
-        self,
-        inputs: Dict[str, Any]
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         output = cast(
             Dict[str, Any],

+ 10 - 8
api/core/chain/main_chain_builder.py

@@ -1,11 +1,9 @@
-from typing import Optional, List
+from typing import Optional, List, cast
 
-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )
@@ -42,9 +42,8 @@ class MainChainBuilder:
             return None
 
         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)
 
         # build main chain
         overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
         return overall_chain
 
     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
                     tenant_id=tenant_id,
                     datasets=datasets,
                     conversation_message_task=conversation_message_task,
-                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                    rest_tokens=rest_tokens,
+                    callbacks=[DifyStdOutCallbackHandler()]
                 )
                 chains.append(multi_dataset_router_chain)
 

+ 62 - 16
api/core/chain/multi_dataset_router_chain.py

@@ -1,9 +1,9 @@
+import math
 from typing import Mapping, List, Dict, Any, Optional
 
-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseLanguageModel
 from pydantic import Extra
 
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule
 
+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
 
     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""
 
     class Config:
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )
 
-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
+
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
+
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )
 
-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool
 
         return cls(
             router_chain=router_chain,
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )
 
+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
         self,
-        inputs: Dict[str, Any]
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}

+ 7 - 2
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 
 
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}

+ 12 - 3
api/core/chain/tool_chain.py

@@ -1,5 +1,6 @@
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
 
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]
 
-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}
 
-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)

+ 42 - 17
api/core/completion.py

@@ -1,17 +1,18 @@
 import logging
 from typing import Optional, List, Union, Tuple
 
-from langchain.callbacks import CallbackManager
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
+from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
 from core.llm.llm_builder import LLMBuilder
 from core.chain.main_chain_builder import MainChainBuilder
@@ -34,8 +35,6 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
-        cls.validate_query_tokens(app.tenant_id, app_model_config, query)
-
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -48,6 +47,14 @@ class Completion:
 
             inputs = conversation.inputs
 
+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            tenant_id=app.tenant_id,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
+
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
             app=app,
@@ -64,6 +71,7 @@ class Completion:
         main_chain = MainChainBuilder.to_langchain_components(
             tenant_id=app.tenant_id,
             agent_mode=app_model_config.agent_mode_dict,
+            rest_tokens=rest_tokens_for_context_and_memory,
             memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
             conversation_message_task=conversation_message_task
         )
@@ -115,7 +123,7 @@ class Completion:
             memory=memory
         )
 
-        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
+        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
 
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
             return messages, ['\nHuman:']
 
     @classmethod
-    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool,
-                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
+    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+                          streaming: bool,
+                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
-            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
+            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
         else:
-            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
-
-        return CallbackManager(callback_handlers)
+            return [llm_callback_handler, DifyStdOutCallbackHandler()]
 
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
         return memory
 
     @classmethod
-    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
+    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+                                 query: str, inputs: dict) -> int:
         llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict
@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
         model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
         max_tokens = llm.max_tokens
 
-        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
-            raise LLMBadRequestError("Query is too long")
+        # get prompt without memory and context
+        prompt, _ = cls.get_main_llm_prompt(
+            mode=mode,
+            llm=llm,
+            pre_prompt=app_model_config.pre_prompt,
+            query=query,
+            inputs=inputs,
+            chain_output=None,
+            memory=None
+        )
+
+        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
+            else llm.get_num_tokens_from_messages(prompt)
+
+        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                                     "or shrink the max token, or switch to a llm with a larger token limit size.")
+
+        return rest_tokens
 
     @classmethod
     def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
             streaming=streaming
         )
 
-        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
+        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
 
         cls.recale_llm_max_tokens(
             final_llm=llm,

+ 4 - 4
api/core/conversation_message_task.py

@@ -293,12 +293,12 @@ class PubHandler:
         if not user:
             raise ValueError("user is required")
 
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result:{}-{}".format(user_str, task_id)
 
     @classmethod
     def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
-        user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
+        user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
         return "generate_result_stopped:{}-{}".format(user_str, task_id)
 
     def pub_text(self, text: str):
@@ -306,10 +306,10 @@ class PubHandler:
             'event': 'message',
             'data': {
                 'task_id': self._task_id,
-                'message_id': self._message.id,
+                'message_id': str(self._message.id),
                 'text': text,
                 'mode': self._conversation.mode,
-                'conversation_id': self._conversation.id
+                'conversation_id': str(self._conversation.id)
             }
         }
 

+ 43 - 0
api/core/data_loader/file_extractor.py

@@ -0,0 +1,43 @@
+import tempfile
+from pathlib import Path
+from typing import List, Union
+
+from langchain.document_loaders import TextLoader, Docx2txtLoader
+from langchain.schema import Document
+
+from core.data_loader.loader.csv import CSVLoader
+from core.data_loader.loader.excel import ExcelLoader
+from core.data_loader.loader.html import HTMLLoader
+from core.data_loader.loader.markdown import MarkdownLoader
+from core.data_loader.loader.pdf import PdfLoader
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+
+class FileExtractor:
+    @classmethod
+    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(upload_file.key).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            storage.download(upload_file.key, file_path)
+
+            input_file = Path(file_path)
+            delimiter = '\n'
+            if input_file.suffix == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif input_file.suffix == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif input_file.suffix in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif input_file.suffix in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif input_file.suffix == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif input_file.suffix == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)
+
+            return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

+ 67 - 0
api/core/data_loader/loader/csv.py

@@ -0,0 +1,67 @@
+import logging
+from typing import Optional, Dict, List
+
+from langchain.document_loaders import CSVLoader as LCCSVLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+
+from models.dataset import Document
+
+logger = logging.getLogger(__name__)
+
+
+class CSVLoader(LCCSVLoader):
+    def __init__(
+            self,
+            file_path: str,
+            source_column: Optional[str] = None,
+            csv_args: Optional[Dict] = None,
+            encoding: Optional[str] = None,
+            autodetect_encoding: bool = True,
+    ):
+        self.file_path = file_path
+        self.source_column = source_column
+        self.encoding = encoding
+        self.csv_args = csv_args or {}
+        self.autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        """Load data into document objects."""
+        try:
+            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
+                docs = self._read_from_file(csvfile)
+        except UnicodeDecodeError as e:
+            if self.autodetect_encoding:
+                detected_encodings = detect_file_encodings(self.file_path)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
+                            docs = self._read_from_file(csvfile)
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {self.file_path}") from e
+
+        return docs
+
+    def _read_from_file(self, csvfile):
+        docs = []
+        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
+        for i, row in enumerate(csv_reader):
+            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
+            try:
+                source = (
+                    row[self.source_column]
+                    if self.source_column is not None
+                    else ''
+                )
+            except KeyError:
+                raise ValueError(
+                    f"Source column '{self.source_column}' not found in CSV file."
+                )
+            metadata = {"source": source, "row": i}
+            doc = Document(page_content=content, metadata=metadata)
+            docs.append(doc)
+
+        return docs

+ 43 - 0
api/core/data_loader/loader/excel.py

@@ -0,0 +1,43 @@
+import json
+import logging
+from typing import List
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from openpyxl.reader.excel import load_workbook
+
+logger = logging.getLogger(__name__)
+
+
+class ExcelLoader(BaseLoader):
+    """Load xlxs files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        data = []
+        keys = []
+        wb = load_workbook(filename=self._file_path, read_only=True)
+        # loop over all sheets
+        for sheet in wb:
+            for row in sheet.iter_rows(values_only=True):
+                if all(v is None for v in row):
+                    continue
+                if keys == []:
+                    keys = list(map(str, row))
+                else:
+                    row_dict = dict(zip(keys, row))
+                    row_dict = {k: v for k, v in row_dict.items() if v}
+                    data.append(json.dumps(row_dict, ensure_ascii=False))
+
+        return [Document(page_content='\n\n'.join(data))]

+ 35 - 0
api/core/data_loader/loader/html.py

@@ -0,0 +1,35 @@
+import logging
+from typing import List
+
+from bs4 import BeautifulSoup
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class HTMLLoader(BaseLoader):
+    """Load html files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+
+    def load(self) -> List[Document]:
+        return [Document(page_content=self._load_as_text())]
+
+    def _load_as_text(self) -> str:
+        with open(self._file_path, "rb") as fp:
+            soup = BeautifulSoup(fp, 'html.parser')
+            text = soup.get_text()
+            text = text.strip() if text else ''
+
+        return text

+ 134 - 0
api/core/data_loader/loader/markdown.py

@@ -0,0 +1,134 @@
+import logging
+import re
+from typing import Optional, List, Tuple, cast
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.helpers import detect_file_encodings
+from langchain.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class MarkdownLoader(BaseLoader):
+    """Load md files.
+
+
+    Args:
+        file_path: Path to the file to load.
+
+        remove_hyperlinks: Whether to remove hyperlinks from the text.
+
+        remove_images: Whether to remove images from the text.
+
+        encoding: File encoding to use. If `None`, the file will be loaded
+        with the default system encoding.
+
+        autodetect_encoding: Whether to try to autodetect the file encoding
+            if the specified encoding fails.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        remove_hyperlinks: bool = True,
+        remove_images: bool = True,
+        encoding: Optional[str] = None,
+        autodetect_encoding: bool = True,
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._remove_hyperlinks = remove_hyperlinks
+        self._remove_images = remove_images
+        self._encoding = encoding
+        self._autodetect_encoding = autodetect_encoding
+
+    def load(self) -> List[Document]:
+        tups = self.parse_tups(self._file_path)
+        documents = []
+        for header, value in tups:
+            value = value.strip()
+            if header is None:
+                documents.append(Document(page_content=value))
+            else:
+                documents.append(Document(page_content=f"\n\n{header}\n{value}"))
+
+        return documents
+
+    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
+        """Convert a markdown file to a dictionary.
+
+        The keys are the headers and the values are the text under each header.
+
+        """
+        markdown_tups: List[Tuple[Optional[str], str]] = []
+        lines = markdown_text.split("\n")
+
+        current_header = None
+        current_text = ""
+
+        for line in lines:
+            header_match = re.match(r"^#+\s", line)
+            if header_match:
+                if current_header is not None:
+                    markdown_tups.append((current_header, current_text))
+
+                current_header = line
+                current_text = ""
+            else:
+                current_text += line + "\n"
+        markdown_tups.append((current_header, current_text))
+
+        if current_header is not None:
+            # pass linting, assert keys are defined
+            markdown_tups = [
+                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
+                for key, value in markdown_tups
+            ]
+        else:
+            markdown_tups = [
+                (key, re.sub("\n", "", value)) for key, value in markdown_tups
+            ]
+
+        return markdown_tups
+
+    def remove_images(self, content: str) -> str:
+        """Get a dictionary of a markdown file from its path."""
+        pattern = r"!{1}\[\[(.*)\]\]"
+        content = re.sub(pattern, "", content)
+        return content
+
+    def remove_hyperlinks(self, content: str) -> str:
+        """Get a dictionary of a markdown file from its path."""
+        pattern = r"\[(.*?)\]\((.*?)\)"
+        content = re.sub(pattern, r"\1", content)
+        return content
+
+    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
+        """Parse file into tuples."""
+        content = ""
+        try:
+            with open(filepath, "r", encoding=self._encoding) as f:
+                content = f.read()
+        except UnicodeDecodeError as e:
+            if self._autodetect_encoding:
+                detected_encodings = detect_file_encodings(filepath)
+                for encoding in detected_encodings:
+                    logger.debug("Trying encoding: ", encoding.encoding)
+                    try:
+                        with open(filepath, encoding=encoding.encoding) as f:
+                            content = f.read()
+                        break
+                    except UnicodeDecodeError:
+                        continue
+            else:
+                raise RuntimeError(f"Error loading {filepath}") from e
+        except Exception as e:
+            raise RuntimeError(f"Error loading {filepath}") from e
+
+        if self._remove_hyperlinks:
+            content = self.remove_hyperlinks(content)
+
+        if self._remove_images:
+            content = self.remove_images(content)
+
+        return self.markdown_to_tups(content)

+ 236 - 236
api/core/data_source/notion.py → api/core/data_loader/loader/notion.py

@@ -1,68 +1,162 @@
-"""Notion reader."""
 import json
 import logging
-import os
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import List, Dict, Any, Optional
 
-import requests  # type: ignore
+import requests
+from flask import current_app
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
 
-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
+from extensions.ext_database import db
+from models.dataset import Document as DocumentModel
+from models.source import DataSourceBinding
+
+logger = logging.getLogger(__name__)
 
-INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
-logger = logging.getLogger(__name__)
-
 
-# TODO: Notion DB reader coming soon!
-class NotionPageReader(BaseReader):
-    """Notion Page reader.
 
-    Reads a set of Notion pages.
-
-    Args:
-        integration_token (str): Notion integration token.
-
-    """
-
-    def __init__(self, integration_token: Optional[str] = None) -> None:
-        """Initialize with parameters."""
-        if integration_token is None:
-            integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
+class NotionLoader(BaseLoader):
+    def __init__(
+            self,
+            notion_access_token: str,
+            notion_workspace_id: str,
+            notion_obj_id: str,
+            notion_page_type: str,
+            document_model: Optional[DocumentModel] = None
+    ):
+        self._document_model = document_model
+        self._notion_workspace_id = notion_workspace_id
+        self._notion_obj_id = notion_obj_id
+        self._notion_page_type = notion_page_type
+        self._notion_access_token = notion_access_token
+
+        if not self._notion_access_token:
+            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
             if integration_token is None:
                 raise ValueError(
                     "Must specify `integration_token` or set environment "
                     "variable `NOTION_INTEGRATION_TOKEN`."
                 )
-        self.token = integration_token
-        self.headers = {
-            "Authorization": "Bearer " + self.token,
-            "Content-Type": "application/json",
-            "Notion-Version": "2022-06-28",
-        }
 
-    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
-        """Read a block."""
-        done = False
+            self._notion_access_token = integration_token
+
+    @classmethod
+    def from_document(cls, document_model: DocumentModel):
+        data_source_info = document_model.data_source_info_dict
+        if not data_source_info or 'notion_page_id' not in data_source_info \
+                or 'notion_workspace_id' not in data_source_info:
+            raise ValueError("no notion page found")
+
+        notion_workspace_id = data_source_info['notion_workspace_id']
+        notion_obj_id = data_source_info['notion_page_id']
+        notion_page_type = data_source_info['type']
+        notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
+
+        return cls(
+            notion_access_token=notion_access_token,
+            notion_workspace_id=notion_workspace_id,
+            notion_obj_id=notion_obj_id,
+            notion_page_type=notion_page_type,
+            document_model=document_model
+        )
+
+    def load(self) -> List[Document]:
+        self.update_last_edited_time(
+            self._document_model
+        )
+
+        text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
+
+        return text_docs
+
+    def _load_data_as_documents(
+            self, notion_obj_id: str, notion_page_type: str
+    ) -> List[Document]:
+        docs = []
+        if notion_page_type == 'database':
+            # get all the pages in the database
+            page_text = self._get_notion_database_data(notion_obj_id)
+            docs.append(Document(page_content=page_text))
+        elif notion_page_type == 'page':
+            page_text_list = self._get_notion_block_data(notion_obj_id)
+            for page_text in page_text_list:
+                docs.append(Document(page_content=page_text))
+        else:
+            raise ValueError("notion page type not supported")
+
+        return docs
+
+    def _get_notion_database_data(
+            self, database_id: str, query_dict: Dict[str, Any] = {}
+    ) -> str:
+        """Get all the pages from a Notion database."""
+        res = requests.post(
+            DATABASE_URL_TMPL.format(database_id=database_id),
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict,
+        )
+
+        data = res.json()
+
+        database_content_list = []
+        if 'results' not in data or data["results"] is None:
+            return ""
+        for result in data["results"]:
+            properties = result['properties']
+            data = {}
+            for property_name, property_value in properties.items():
+                type = property_value['type']
+                if type == 'multi_select':
+                    value = []
+                    multi_select_list = property_value[type]
+                    for multi_select in multi_select_list:
+                        value.append(multi_select['name'])
+                elif type == 'rich_text' or type == 'title':
+                    if len(property_value[type]) > 0:
+                        value = property_value[type][0]['plain_text']
+                    else:
+                        value = ''
+                elif type == 'select' or type == 'status':
+                    if property_value[type]:
+                        value = property_value[type]['name']
+                    else:
+                        value = ''
+                else:
+                    value = property_value[type]
+                data[property_name] = value
+            database_content_list.append(json.dumps(data, ensure_ascii=False))
+
+        return "\n\n".join(database_content_list)
+
+    def _get_notion_block_data(self, page_id: str) -> List[str]:
         result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
+        cur_block_id = page_id
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}
 
             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
-            if 'results' not in data or data["results"] is None:
-                done = True
-                break
+            # current block's heading
             heading = ''
             for result in data["results"]:
                 result_type = result["type"]
@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
+                    text += "\n\n"
                     result_lines_arr.append(text)
                 else:
                     if "rich_text" in result_obj:
@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader):
                             # skip if doesn't have text object
                             if "text" in rich_text:
                                 text = rich_text["text"]["content"]
-                                prefix = "\t" * num_tabs
-                                cur_result_text_arr.append(prefix + text)
+                                cur_result_text_arr.append(text)
                                 if result_type in HEADING_TYPE:
                                     heading = text
+
                     result_block_id = result["id"]
                     has_children = result["has_children"]
                     block_type = result["type"]
                     if has_children and block_type != 'child_page':
                         children_text = self._read_block(
-                            result_block_id, num_tabs=num_tabs + 1
+                            result_block_id, num_tabs=1
                         )
                         cur_result_text_arr.append(children_text)
 
                     cur_result_text = "\n".join(cur_result_text_arr)
+                    cur_result_text += "\n\n"
                     if result_type in HEADING_TYPE:
                         result_lines_arr.append(cur_result_text)
                     else:
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
 
             if data["next_cursor"] is None:
-                done = True
-                break
-            else:
-                cur_block_id = data["next_cursor"]
-
-        result_lines = "\n".join(result_lines_arr)
-        return result_lines
-
-    def _read_table_rows(self, block_id: str) -> str:
-        """Read table rows."""
-        done = False
-        result_lines_arr = []
-        cur_block_id = block_id
-        while not done:
-            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
-
-            res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
-            )
-            data = res.json()
-            # get table headers text
-            table_header_cell_texts = []
-            tabel_header_cells = data["results"][0]['table_row']['cells']
-            for tabel_header_cell in tabel_header_cells:
-                if tabel_header_cell:
-                    for table_header_cell_text in tabel_header_cell:
-                        text = table_header_cell_text["text"]["content"]
-                        table_header_cell_texts.append(text)
-            # get table columns text and format
-            results = data["results"]
-            for i in range(len(results)-1):
-                column_texts = []
-                tabel_column_cells = data["results"][i+1]['table_row']['cells']
-                for j in range(len(tabel_column_cells)):
-                    if tabel_column_cells[j]:
-                        for table_column_cell_text in tabel_column_cells[j]:
-                            column_text = table_column_cell_text["text"]["content"]
-                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
-
-                cur_result_text = "\n".join(column_texts)
-                result_lines_arr.append(cur_result_text)
-
-            if data["next_cursor"] is None:
-                done = True
                 break
             else:
                 cur_block_id = data["next_cursor"]
+        return result_lines_arr
 
-        result_lines = "\n".join(result_lines_arr)
-        return result_lines
-    def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
+    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
         """Read a block."""
-        done = False
         result_lines_arr = []
         cur_block_id = block_id
-        while not done:
+        while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
             query_dict: Dict[str, Any] = {}
 
             res = requests.request(
-                "GET", block_url, headers=self.headers, json=query_dict
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
             )
             data = res.json()
-            # current block's heading
+            if 'results' not in data or data["results"] is None:
+                break
             heading = ''
             for result in data["results"]:
                 result_type = result["type"]
@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader):
                 if result_type == 'table':
                     result_block_id = result["id"]
                     text = self._read_table_rows(result_block_id)
-                    text += "\n\n"
                     result_lines_arr.append(text)
                 else:
                     if "rich_text" in result_obj:
@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader):
                             # skip if doesn't have text object
                             if "text" in rich_text:
                                 text = rich_text["text"]["content"]
-                                cur_result_text_arr.append(text)
+                                prefix = "\t" * num_tabs
+                                cur_result_text_arr.append(prefix + text)
                                 if result_type in HEADING_TYPE:
                                     heading = text
-
                     result_block_id = result["id"]
                     has_children = result["has_children"]
                     block_type = result["type"]
@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader):
                         cur_result_text_arr.append(children_text)
 
                     cur_result_text = "\n".join(cur_result_text_arr)
-                    cur_result_text += "\n\n"
                     if result_type in HEADING_TYPE:
                         result_lines_arr.append(cur_result_text)
                     else:
                         result_lines_arr.append(f'{heading}\n{cur_result_text}')
 
             if data["next_cursor"] is None:
-                done = True
                 break
             else:
                 cur_block_id = data["next_cursor"]
-        return result_lines_arr
-
-    def read_page(self, page_id: str) -> str:
-        """Read a page."""
-        return self._read_block(page_id)
-
-    def read_page_as_documents(self, page_id: str) -> List[str]:
-        """Read a page as documents."""
-        return self._read_parent_blocks(page_id)
-
-    def query_database_data(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> str:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-                (
-            DATABASE_URL_TMPL.format(database_id=database_id),
-            headers=self.headers,
-            json=query_dict,
-        )
-        data = res.json()
-        database_content_list = []
-        if 'results' not in data or data["results"] is None:
-            return ""
-        for result in data["results"]:
-            properties = result['properties']
-            data = {}
-            for property_name, property_value in properties.items():
-                type = property_value['type']
-                if type == 'multi_select':
-                    value = []
-                    multi_select_list = property_value[type]
-                    for multi_select in multi_select_list:
-                        value.append(multi_select['name'])
-                elif type == 'rich_text' or type == 'title':
-                    if len(property_value[type]) > 0:
-                        value = property_value[type][0]['plain_text']
-                    else:
-                        value = ''
-                elif type == 'select' or type == 'status':
-                    if property_value[type]:
-                        value = property_value[type]['name']
-                    else:
-                        value = ''
-                else:
-                    value = property_value[type]
-                data[property_name] = value
-            database_content_list.append(json.dumps(data, ensure_ascii=False))
-
-        return "\n\n".join(database_content_list)
-
-    def query_database(
-            self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> List[str]:
-        """Get all the pages from a Notion database."""
-        res = requests.post\
-                (
-            DATABASE_URL_TMPL.format(database_id=database_id),
-            headers=self.headers,
-            json=query_dict,
-        )
-        data = res.json()
-        page_ids = []
-        for result in data["results"]:
-            page_id = result["id"]
-            page_ids.append(page_id)
 
-        return page_ids
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
 
-    def search(self, query: str) -> List[str]:
-        """Search Notion page given a text query."""
+    def _read_table_rows(self, block_id: str) -> str:
+        """Read table rows."""
         done = False
-        next_cursor: Optional[str] = None
-        page_ids = []
+        result_lines_arr = []
+        cur_block_id = block_id
         while not done:
-            query_dict = {
-                "query": query,
-            }
-            if next_cursor is not None:
-                query_dict["start_cursor"] = next_cursor
-            res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
+            block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
+            query_dict: Dict[str, Any] = {}
+
+            res = requests.request(
+                "GET",
+                block_url,
+                headers={
+                    "Authorization": "Bearer " + self._notion_access_token,
+                    "Content-Type": "application/json",
+                    "Notion-Version": "2022-06-28",
+                },
+                json=query_dict
+            )
             data = res.json()
-            for result in data["results"]:
-                page_id = result["id"]
-                page_ids.append(page_id)
+            # get table headers text
+            table_header_cell_texts = []
+            tabel_header_cells = data["results"][0]['table_row']['cells']
+            for tabel_header_cell in tabel_header_cells:
+                if tabel_header_cell:
+                    for table_header_cell_text in tabel_header_cell:
+                        text = table_header_cell_text["text"]["content"]
+                        table_header_cell_texts.append(text)
+            # get table columns text and format
+            results = data["results"]
+            for i in range(len(results) - 1):
+                column_texts = []
+                tabel_column_cells = data["results"][i + 1]['table_row']['cells']
+                for j in range(len(tabel_column_cells)):
+                    if tabel_column_cells[j]:
+                        for table_column_cell_text in tabel_column_cells[j]:
+                            column_text = table_column_cell_text["text"]["content"]
+                            column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
+
+                cur_result_text = "\n".join(column_texts)
+                result_lines_arr.append(cur_result_text)
 
             if data["next_cursor"] is None:
                 done = True
                 break
             else:
-                next_cursor = data["next_cursor"]
-        return page_ids
+                cur_block_id = data["next_cursor"]
 
-    def load_data(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        """Load data from the input directory.
+        result_lines = "\n".join(result_lines_arr)
+        return result_lines
 
-        Args:
-            page_ids (List[str]): List of page ids to load.
+    def update_last_edited_time(self, document_model: DocumentModel):
+        if not document_model:
+            return
 
-        Returns:
-            List[Document]: List of documents.
+        last_edited_time = self.get_notion_last_edited_time()
+        data_source_info = document_model.data_source_info_dict
+        data_source_info['last_edited_time'] = last_edited_time
+        update_params = {
+            DocumentModel.data_source_info: json.dumps(data_source_info)
+        }
 
-        """
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_ids = self.query_database(database_id)
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
-        else:
-            for page_id in page_ids:
-                page_text = self.read_page(page_id)
-                docs.append(Document(page_text))
+        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.commit()
 
-        return docs
-
-    def load_data_as_documents(
-            self, page_ids: List[str] = [], database_id: Optional[str] = None
-    ) -> List[Document]:
-        if not page_ids and not database_id:
-            raise ValueError("Must specify either `page_ids` or `database_id`.")
-        docs = []
-        if database_id is not None:
-            # get all the pages in the database
-            page_text = self.query_database_data(database_id)
-            docs.append(Document(page_text))
+    def get_notion_last_edited_time(self) -> str:
+        obj_id = self._notion_obj_id
+        page_type = self._notion_page_type
+        if page_type == 'database':
+            retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
         else:
-            for page_id in page_ids:
-                page_text_list = self.read_page_as_documents(page_id)
-                for page_text in page_text_list:
-                    docs.append(Document(page_text))
+            retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
 
-        return docs
-
-    def get_page_last_edited_time(self, page_id: str) -> str:
-        retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
         query_dict: Dict[str, Any] = {}
 
         res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
+            "GET",
+            retrieve_page_url,
+            headers={
+                "Authorization": "Bearer " + self._notion_access_token,
+                "Content-Type": "application/json",
+                "Notion-Version": "2022-06-28",
+            },
+            json=query_dict
         )
-        data = res.json()
-        return data["last_edited_time"]
-
-    def get_database_last_edited_time(self, database_id: str) -> str:
-        retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
-        query_dict: Dict[str, Any] = {}
 
-        res = requests.request(
-            "GET", retrieve_page_url, headers=self.headers, json=query_dict
-        )
         data = res.json()
         return data["last_edited_time"]
 
+    @classmethod
+    def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
+        data_source_binding = DataSourceBinding.query.filter(
+            db.and_(
+                DataSourceBinding.tenant_id == tenant_id,
+                DataSourceBinding.provider == 'notion',
+                DataSourceBinding.disabled == False,
+                DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
+            )
+        ).first()
+
+        if not data_source_binding:
+            raise Exception(f'No notion data source binding found for tenant {tenant_id} '
+                            f'and notion workspace {notion_workspace_id}')
 
-if __name__ == "__main__":
-    reader = NotionPageReader()
-    logger.info(reader.search("What I"))
+        return data_source_binding.access_token

+ 55 - 0
api/core/data_loader/loader/pdf.py

@@ -0,0 +1,55 @@
+import logging
+from typing import List, Optional
+
+from langchain.document_loaders import PyPDFium2Loader
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+
+from extensions.ext_storage import storage
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)
+
+
+class PdfLoader(BaseLoader):
+    """Load pdf files.
+
+
+    Args:
+        file_path: Path to the file to load.
+    """
+
+    def __init__(
+        self,
+        file_path: str,
+        upload_file: Optional[UploadFile] = None
+    ):
+        """Initialize with file path."""
+        self._file_path = file_path
+        self._upload_file = upload_file
+
+    def load(self) -> List[Document]:
+        plaintext_file_key = ''
+        plaintext_file_exists = False
+        if self._upload_file:
+            if self._upload_file.hash:
+                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+                                     + self._upload_file.hash + '.0625.plaintext'
+                try:
+                    text = storage.load(plaintext_file_key).decode('utf-8')
+                    plaintext_file_exists = True
+                    return [Document(page_content=text)]
+                except FileNotFoundError:
+                    pass
+        documents = PyPDFium2Loader(file_path=self._file_path).load()
+        text_list = []
+        for document in documents:
+            text_list.append(document.page_content)
+        text = "\n\n".join(text_list)
+
+        # save plaintext file for caching
+        if not plaintext_file_exists and plaintext_file_key:
+            storage.save(plaintext_file_key, text.encode('utf-8'))
+
+        return documents
+

+ 35 - 45
api/core/docstore/dataset_docstore.py

@@ -1,10 +1,6 @@
 from typing import Any, Dict, Optional, Sequence
 
-import tiktoken
-from llama_index.data_structs import Node
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.docstore.utils import json_to_doc
-from llama_index.schema import BaseDocument
+from langchain.schema import Document
 from sqlalchemy import func
 
 from core.llm.token_calculator import TokenCalculator
@@ -12,7 +8,7 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
 
 
-class DatesetDocumentStore(BaseDocumentStore):
+class DatesetDocumentStore:
     def __init__(
         self,
         dataset: Dataset,
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
         return self._embedding_model_name
 
     @property
-    def docs(self) -> Dict[str, BaseDocument]:
+    def docs(self) -> Dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id
         ).all()
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
         output = {}
         for document_segment in document_segments:
             doc_id = document_segment.index_node_id
-            result = self.segment_to_dict(document_segment)
-            output[doc_id] = json_to_doc(result)
+            output[doc_id] = Document(
+                page_content=document_segment.content,
+                metadata={
+                    "doc_id": document_segment.index_node_id,
+                    "doc_hash": document_segment.index_node_hash,
+                    "document_id": document_segment.document_id,
+                    "dataset_id": document_segment.dataset_id,
+                }
+            )
 
         return output
 
     def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
+        self, docs: Sequence[Document], allow_update: bool = True
     ) -> None:
         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
             DocumentSegment.document == self._document_id
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
             max_position = 0
 
         for doc in docs:
-            if doc.is_doc_id_none:
-                raise ValueError("doc_id not set")
+            if not isinstance(doc, Document):
+                raise ValueError("doc must be a Document")
 
-            if not isinstance(doc, Node):
-                raise ValueError("doc must be a Node")
-
-            segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
+            segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
 
             # NOTE: doc could already exist in the store, but we overwrite it
             if not allow_update and segment_document:
                 raise ValueError(
-                    f"doc_id {doc.get_doc_id()} already exists. "
+                    f"doc_id {doc.metadata['doc_id']} already exists. "
                     "Set allow_update to True to overwrite."
                 )
 
             # calc embedding use tokens
-            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
+            tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
 
             if not segment_document:
                 max_position += 1
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
                     tenant_id=self._dataset.tenant_id,
                     dataset_id=self._dataset.id,
                     document_id=self._document_id,
-                    index_node_id=doc.get_doc_id(),
-                    index_node_hash=doc.get_doc_hash(),
+                    index_node_id=doc.metadata['doc_id'],
+                    index_node_hash=doc.metadata['doc_hash'],
                     position=max_position,
-                    content=doc.get_text(),
-                    word_count=len(doc.get_text()),
+                    content=doc.page_content,
+                    word_count=len(doc.page_content),
                     tokens=tokens,
                     created_by=self._user_id,
                 )
                 db.session.add(segment_document)
             else:
-                segment_document.content = doc.get_text()
-                segment_document.index_node_hash = doc.get_doc_hash()
-                segment_document.word_count = len(doc.get_text())
+                segment_document.content = doc.page_content
+                segment_document.index_node_hash = doc.metadata['doc_hash']
+                segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
 
             db.session.commit()
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
 
     def get_document(
         self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
+    ) -> Optional[Document]:
         document_segment = self.get_document_segment(doc_id)
 
         if document_segment is None:
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
             else:
                 return None
 
-        result = self.segment_to_dict(document_segment)
-        return json_to_doc(result)
+        return Document(
+            page_content=document_segment.content,
+            metadata={
+                "doc_id": document_segment.index_node_id,
+                "doc_hash": document_segment.index_node_hash,
+                "document_id": document_segment.document_id,
+                "dataset_id": document_segment.dataset_id,
+            }
+        )
 
     def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
         document_segment = self.get_document_segment(doc_id)
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
 
         return document_segment.index_node_hash
 
-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))
-
     def get_document_segment(self, doc_id: str) -> DocumentSegment:
         document_segment = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id,
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
         ).first()
 
         return document_segment
-
-    def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
-        return {
-            "doc_id": segment.index_node_id,
-            "doc_hash": segment.index_node_hash,
-            "text": segment.content,
-            "__type__": Node.get_type()
-        }

+ 0 - 51
api/core/docstore/empty_docstore.py

@@ -1,51 +0,0 @@
-from typing import Any, Dict, Optional, Sequence
-from llama_index.docstore.types import BaseDocumentStore
-from llama_index.schema import BaseDocument
-
-
-class EmptyDocumentStore(BaseDocumentStore):
-    @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
-        return cls()
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Serialize to dict."""
-        return {}
-
-    @property
-    def docs(self) -> Dict[str, BaseDocument]:
-        return {}
-
-    def add_documents(
-        self, docs: Sequence[BaseDocument], allow_update: bool = True
-    ) -> None:
-        pass
-
-    def document_exists(self, doc_id: str) -> bool:
-        """Check if document exists."""
-        return False
-
-    def get_document(
-        self, doc_id: str, raise_error: bool = True
-    ) -> Optional[BaseDocument]:
-        return None
-
-    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
-        pass
-
-    def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
-        """Set the hash for a given doc_id."""
-        pass
-
-    def get_document_hash(self, doc_id: str) -> Optional[str]:
-        """Get the stored hash for a document, if it exists."""
-        return None
-
-    def update_docstore(self, other: "BaseDocumentStore") -> None:
-        """Update docstore.
-
-        Args:
-            other (BaseDocumentStore): docstore to update from
-
-        """
-        self.add_documents(list(other.docs.values()))

+ 72 - 0
api/core/embedding/cached_embedding.py

@@ -0,0 +1,72 @@
+import logging
+from typing import List
+
+from langchain.embeddings.base import Embeddings
+from sqlalchemy.exc import IntegrityError
+
+from extensions.ext_database import db
+from libs import helper
+from models.dataset import Embedding
+
+
+class CacheEmbedding(Embeddings):
+    def __init__(self, embeddings: Embeddings):
+        self._embeddings = embeddings
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed search docs."""
+        # use doc embedding cache or store if not exists
+        text_embeddings = []
+        embedding_queue_texts = []
+        for text in texts:
+            hash = helper.generate_text_hash(text)
+            embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+            if embedding:
+                text_embeddings.append(embedding.get_embedding())
+            else:
+                embedding_queue_texts.append(text)
+
+        embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
+
+        i = 0
+        for text in embedding_queue_texts:
+            hash = helper.generate_text_hash(text)
+
+            try:
+                embedding = Embedding(hash=hash)
+                embedding.set_embedding(embedding_results[i])
+                db.session.add(embedding)
+                db.session.commit()
+            except IntegrityError:
+                db.session.rollback()
+                continue
+            except:
+                logging.exception('Failed to add embedding to db')
+                continue
+
+            i += 1
+
+        text_embeddings.extend(embedding_results)
+        return text_embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query text."""
+        # use doc embedding cache or store if not exists
+        hash = helper.generate_text_hash(text)
+        embedding = db.session.query(Embedding).filter_by(hash=hash).first()
+        if embedding:
+            return embedding.get_embedding()
+
+        embedding_results = self._embeddings.embed_query(text)
+
+        try:
+            embedding = Embedding(hash=hash)
+            embedding.set_embedding(embedding_results)
+            db.session.add(embedding)
+            db.session.commit()
+        except IntegrityError:
+            db.session.rollback()
+        except:
+            logging.exception('Failed to add embedding to db')
+
+        return embedding_results

+ 0 - 214
api/core/embedding/openai_embedding.py

@@ -1,214 +0,0 @@
-from typing import Optional, Any, List
-
-import openai
-from llama_index.embeddings.base import BaseEmbedding
-from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
-    _TEXT_MODE_MODEL_DICT
-from tenacity import wait_random_exponential, retry, stop_after_attempt
-
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(
-        text: str,
-        engine: Optional[str] = None,
-        api_key: Optional[str] = None,
-        **kwargs
-) -> List[float]:
-    """Get embedding.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
-    float]:
-    """Asynchronously get embedding.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    # replace newlines, which can negatively affect performance.
-    text = text.replace("\n", " ")
-
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
-        "embedding"
-    ]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embeddings(
-        list_of_text: List[str],
-        engine: Optional[str] = None,
-        api_key: Optional[str] = None,
-        **kwargs
-) -> List[List[float]]:
-    """Get embeddings.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
-    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
-    return [d["embedding"] for d in data]
-
-
-@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embeddings(
-        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
-) -> List[List[float]]:
-    """Asynchronously get embeddings.
-
-    NOTE: Copied from OpenAI's embedding utils:
-    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
-
-    Copied here to avoid importing unnecessary dependencies
-    like matplotlib, plotly, scipy, sklearn.
-
-    """
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
-
-    # replace newlines, which can negatively affect performance.
-    list_of_text = [text.replace("\n", " ") for text in list_of_text]
-
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
-    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
-    return [d["embedding"] for d in data]
-
-
-class OpenAIEmbedding(BaseEmbedding):
-
-    def __init__(
-            self,
-            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-            deployment_name: Optional[str] = None,
-            openai_api_key: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        new_kwargs = {}
-
-        if 'embed_batch_size' in kwargs:
-            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
-
-        if 'tokenizer' in kwargs:
-            new_kwargs['tokenizer'] = kwargs['tokenizer']
-
-        super().__init__(**new_kwargs)
-        self.mode = OpenAIEmbeddingMode(mode)
-        self.model = OpenAIEmbeddingModelType(model)
-        self.deployment_name = deployment_name
-        self.openai_api_key = openai_api_key
-        self.openai_api_type = kwargs.get('openai_api_type')
-        self.openai_api_version = kwargs.get('openai_api_version')
-        self.openai_api_base = kwargs.get('openai_api_base')
-
-    @handle_llm_exceptions
-    def _get_query_embedding(self, query: str) -> List[float]:
-        """Get query embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _QUERY_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
-                             api_type=self.openai_api_type, api_version=self.openai_api_version,
-                             api_base=self.openai_api_base)
-
-    def _get_text_embedding(self, text: str) -> List[float]:
-        """Get text embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
-                             api_type=self.openai_api_type, api_version=self.openai_api_version,
-                             api_base=self.openai_api_base)
-
-    async def _aget_text_embedding(self, text: str) -> List[float]:
-        """Asynchronously get text embedding."""
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
-                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                    api_base=self.openai_api_base)
-
-    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Get text embeddings.
-
-        By default, this is a wrapper around _get_text_embedding.
-        Can be overriden for batch queries.
-
-        """
-        if self.openai_api_type and self.openai_api_type == 'azure':
-            embeddings = []
-            for text in texts:
-                embeddings.append(self._get_text_embedding(text))
-
-            return embeddings
-
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
-                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                    api_base=self.openai_api_base)
-        return embeddings
-
-    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Asynchronously get text embeddings."""
-        if self.openai_api_type and self.openai_api_type == 'azure':
-            embeddings = []
-            for text in texts:
-                embeddings.append(await self._aget_text_embedding(text))
-
-            return embeddings
-
-        if self.deployment_name is not None:
-            engine = self.deployment_name
-        else:
-            key = (self.mode, self.model)
-            if key not in _TEXT_MODE_MODEL_DICT:
-                raise ValueError(f"Invalid mode, model combination: {key}")
-            engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
-                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
-                                           api_base=self.openai_api_base)
-        return embeddings

+ 59 - 0
api/core/index/base.py

@@ -0,0 +1,59 @@
+from __future__ import annotations
+from abc import abstractmethod, ABC
+from typing import List, Any
+
+from langchain.schema import Document, BaseRetriever
+
+from models.dataset import Dataset
+
+
+class BaseIndex(ABC):
+
+    def __init__(self, dataset: Dataset):
+        self.dataset = dataset
+
+    @abstractmethod
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_texts(self, texts: list[Document], **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def text_exists(self, id: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_ids(self, ids: list[str]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def delete_by_document_id(self, document_id: str):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        raise NotImplementedError
+
+    @abstractmethod
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        raise NotImplementedError
+
+    def delete(self) -> None:
+        raise NotImplementedError
+
+    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
+        for text in texts:
+            doc_id = text.metadata['doc_id']
+            exists_duplicate_node = self.text_exists(doc_id)
+            if exists_duplicate_node:
+                texts.remove(text)
+
+        return texts
+
+    def _get_uuids(self, texts: list[Document]) -> list[str]:
+        return [text.metadata['doc_id'] for text in texts]

+ 41 - 0
api/core/index/index.py

@@ -0,0 +1,41 @@
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+
+
+class IndexBuilder:
+    @classmethod
+    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
+        if indexing_technique == "high_quality":
+            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
+                return None
+
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+
+            return VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+        elif indexing_technique == "economy":
+            return KeywordTableIndex(
+                dataset=dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=10
+                )
+            )
+        else:
+            raise ValueError('Unknown indexing technique')

+ 0 - 60
api/core/index/index_builder.py

@@ -1,60 +0,0 @@
-from langchain.callbacks import CallbackManager
-from llama_index import ServiceContext, PromptHelper, LLMPredictor
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.embedding.openai_embedding import OpenAIEmbedding
-from core.llm.llm_builder import LLMBuilder
-
-
-class IndexBuilder:
-    @classmethod
-    def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
-        # set number of output tokens
-        num_output = 512
-
-        # only for verbose
-        callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
-
-        llm = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name='text-davinci-003',
-            temperature=0,
-            max_tokens=num_output,
-            callback_manager=callback_manager,
-        )
-
-        llm_predictor = LLMPredictor(llm=llm)
-
-        # These parameters here will affect the logic of segmenting the final synthesized response.
-        # The number of refinement iterations in the synthesis process depends
-        # on whether the length of the segmented output exceeds the max_input_size.
-        prompt_helper = PromptHelper(
-            max_input_size=3500,
-            num_output=num_output,
-            max_chunk_overlap=20
-        )
-
-        provider = LLMBuilder.get_default_provider(tenant_id)
-
-        model_credentials = LLMBuilder.get_model_credentials(
-            tenant_id=tenant_id,
-            model_provider=provider,
-            model_name='text-embedding-ada-002'
-        )
-
-        return ServiceContext.from_defaults(
-            llm_predictor=llm_predictor,
-            prompt_helper=prompt_helper,
-            embed_model=OpenAIEmbedding(**model_credentials),
-        )
-
-    @classmethod
-    def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
-        llm = LLMBuilder.to_llm(
-            tenant_id=tenant_id,
-            model_name='fake'
-        )
-
-        return ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )

+ 0 - 159
api/core/index/keyword_table/jieba_keyword_table.py

@@ -1,159 +0,0 @@
-import re
-from typing import (
-    Any,
-    Dict,
-    List,
-    Set,
-    Optional
-)
-
-import jieba.analyse
-
-from core.index.keyword_table.stopwords import STOPWORDS
-from llama_index.indices.query.base import IS
-from llama_index import QueryMode
-from llama_index.indices.base import QueryMap
-from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
-from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
-from llama_index.docstore import BaseDocumentStore
-from llama_index.indices.postprocessor.node import (
-    BaseNodePostprocessor,
-)
-from llama_index.indices.response.response_builder import ResponseMode
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-
-from core.index.query.synthesizer import EnhanceResponseSynthesizer
-
-
-def jieba_extract_keywords(
-        text_chunk: str,
-        max_keywords: Optional[int] = None,
-        expand_with_subtokens: bool = True,
-) -> Set[str]:
-    """Extract keywords with JIEBA tfidf."""
-    keywords = jieba.analyse.extract_tags(
-        sentence=text_chunk,
-        topK=max_keywords,
-    )
-
-    if expand_with_subtokens:
-        return set(expand_tokens_with_subtokens(keywords))
-    else:
-        return set(keywords)
-
-
-def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
-    """Get subtokens from a list of tokens., filtering for stopwords."""
-    results = set()
-    for token in tokens:
-        results.add(token)
-        sub_tokens = re.findall(r"\w+", token)
-        if len(sub_tokens) > 1:
-            results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
-
-    return results
-
-
-class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
-    """GPT JIEBA Keyword Table Index.
-
-    This index uses a JIEBA keyword extractor to extract keywords from the text.
-
-    """
-
-    def _extract_keywords(self, text: str) -> Set[str]:
-        """Extract keywords from text."""
-        return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
-
-    @classmethod
-    def get_query_map(self) -> QueryMap:
-        """Get query map."""
-        super_map = super().get_query_map()
-        super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
-        return super_map
-
-    def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
-        """Delete a document."""
-        # get set of ids that correspond to node
-        node_idxs_to_delete = {doc_id}
-
-        # delete node_idxs from keyword to node idxs mapping
-        keywords_to_delete = set()
-        for keyword, node_idxs in self._index_struct.table.items():
-            if node_idxs_to_delete.intersection(node_idxs):
-                self._index_struct.table[keyword] = node_idxs.difference(
-                    node_idxs_to_delete
-                )
-                if not self._index_struct.table[keyword]:
-                    keywords_to_delete.add(keyword)
-
-        for keyword in keywords_to_delete:
-            del self._index_struct.table[keyword]
-
-
-class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
-    """GPT Keyword Table Index JIEBA Query.
-
-    Extracts keywords using JIEBA keyword extractor.
-    Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
-
-    .. code-block:: python
-
-        response = index.query("<query_str>", mode="jieba")
-
-    See BaseGPTKeywordTableQuery for arguments.
-
-    """
-
-    @classmethod
-    def from_args(
-            cls,
-            index_struct: IS,
-            service_context: ServiceContext,
-            docstore: Optional[BaseDocumentStore] = None,
-            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
-            verbose: bool = False,
-            # response synthesizer args
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_kwargs: Optional[Dict] = None,
-            use_async: bool = False,
-            streaming: bool = False,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-            # class-specific args
-            **kwargs: Any,
-    ) -> "BaseGPTIndexQuery":
-        response_synthesizer = EnhanceResponseSynthesizer.from_args(
-            service_context=service_context,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            simple_template=simple_template,
-            response_mode=response_mode,
-            response_kwargs=response_kwargs,
-            use_async=use_async,
-            streaming=streaming,
-            optimizer=optimizer,
-        )
-        return cls(
-            index_struct=index_struct,
-            service_context=service_context,
-            response_synthesizer=response_synthesizer,
-            docstore=docstore,
-            node_postprocessors=node_postprocessors,
-            verbose=verbose,
-            **kwargs,
-        )
-
-    def _get_keywords(self, query_str: str) -> List[str]:
-        """Extract keywords."""
-        return list(
-            jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
-        )

+ 0 - 135
api/core/index/keyword_table_index.py

@@ -1,135 +0,0 @@
-import json
-from typing import List, Optional
-
-from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
-from llama_index.data_structs import KeywordTable, Node
-from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
-from llama_index.indices.registry import load_index_struct_from_dict
-
-from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.index_builder import IndexBuilder
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.llm.llm_builder import LLMBuilder
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
-
-
-class KeywordTableIndex:
-
-    def __init__(self, dataset: Dataset):
-        self._dataset = dataset
-
-    def add_nodes(self, nodes: List[Node]):
-        llm = LLMBuilder.to_llm(
-            tenant_id=self._dataset.tenant_id,
-            model_name='fake'
-        )
-
-        service_context = ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            index_struct = KeywordTable()
-        else:
-            index_struct_dict = dataset_keyword_table.keyword_table_dict
-            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
-
-        # create index
-        index = GPTJIEBAKeywordTableIndex(
-            index_struct=index_struct,
-            docstore=EmptyDocumentStore(),
-            service_context=service_context
-        )
-
-        for node in nodes:
-            keywords = index._extract_keywords(node.get_text())
-            self.update_segment_keywords(node.doc_id, list(keywords))
-            index._index_struct.add_node(list(keywords), node)
-
-        index_struct_dict = index.index_struct.to_dict()
-
-        if not dataset_keyword_table:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
-                keyword_table=json.dumps(index_struct_dict)
-            )
-            db.session.add(dataset_keyword_table)
-        else:
-            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
-
-        db.session.commit()
-
-    def del_nodes(self, node_ids: List[str]):
-        llm = LLMBuilder.to_llm(
-            tenant_id=self._dataset.tenant_id,
-            model_name='fake'
-        )
-
-        service_context = ServiceContext.from_defaults(
-            llm_predictor=LLMPredictor(llm=llm),
-            embed_model=OpenAIEmbedding()
-        )
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            return
-        else:
-            index_struct_dict = dataset_keyword_table.keyword_table_dict
-            index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
-
-        # create index
-        index = GPTJIEBAKeywordTableIndex(
-            index_struct=index_struct,
-            docstore=EmptyDocumentStore(),
-            service_context=service_context
-        )
-
-        for node_id in node_ids:
-            index.delete(node_id)
-
-        index_struct_dict = index.index_struct.to_dict()
-
-        if not dataset_keyword_table:
-            dataset_keyword_table = DatasetKeywordTable(
-                dataset_id=self._dataset.id,
-                keyword_table=json.dumps(index_struct_dict)
-            )
-            db.session.add(dataset_keyword_table)
-        else:
-            dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
-
-        db.session.commit()
-
-    @property
-    def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
-        docstore = DatesetDocumentStore(
-            dataset=self._dataset,
-            user_id=self._dataset.created_by,
-            embedding_model_name="text-embedding-ada-002"
-        )
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        dataset_keyword_table = self.get_keyword_table()
-        if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
-            return None
-
-        index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
-
-        return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
-
-    def get_keyword_table(self):
-        dataset_keyword_table = self._dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            return dataset_keyword_table
-        return None
-
-    def update_segment_keywords(self, node_id: str, keywords: List[str]):
-        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
-        if document_segment:
-            document_segment.keywords = keywords
-            db.session.commit()

+ 33 - 0
api/core/index/keyword_table_index/jieba_keyword_table_handler.py

@@ -0,0 +1,33 @@
+import re
+from typing import Set
+
+import jieba
+from jieba.analyse import default_tfidf
+
+from core.index.keyword_table_index.stopwords import STOPWORDS
+
+
+class JiebaKeywordTableHandler:
+
+    def __init__(self):
+        default_tfidf.stop_words = STOPWORDS
+
+    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
+        """Extract keywords with JIEBA tfidf."""
+        keywords = jieba.analyse.extract_tags(
+            sentence=text,
+            topK=max_keywords_per_chunk,
+        )
+
+        return set(self._expand_tokens_with_subtokens(keywords))
+
+    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
+        """Get subtokens from a list of tokens., filtering for stopwords."""
+        results = set()
+        for token in tokens:
+            results.add(token)
+            sub_tokens = re.findall(r"\w+", token)
+            if len(sub_tokens) > 1:
+                results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
+
+        return results

+ 238 - 0
api/core/index/keyword_table_index/keyword_table_index.py

@@ -0,0 +1,238 @@
+import json
+from collections import defaultdict
+from typing import Any, List, Optional, Dict
+
+from langchain.schema import Document, BaseRetriever
+from pydantic import BaseModel, Field, Extra
+
+from core.index.base import BaseIndex
+from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
+
+
+class KeywordTableConfig(BaseModel):
+    max_keywords_per_chunk: int = 10
+
+
+class KeywordTableIndex(BaseIndex):
+    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+        super().__init__(dataset)
+        self._config = config
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = {}
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        dataset_keyword_table = DatasetKeywordTable(
+            dataset_id=self.dataset.id,
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self.dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
+        )
+        db.session.add(dataset_keyword_table)
+        db.session.commit()
+
+        self._save_dataset_keyword_table(keyword_table)
+
+        return self
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        keyword_table_handler = JiebaKeywordTableHandler()
+
+        keyword_table = self._get_dataset_keyword_table()
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def text_exists(self, id: str) -> bool:
+        keyword_table = self._get_dataset_keyword_table()
+        return id in set.union(*keyword_table.values())
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        keyword_table = self._get_dataset_keyword_table()
+        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def delete_by_document_id(self, document_id: str):
+        # get segment ids by document_id
+        segments = db.session.query(DocumentSegment).filter(
+            DocumentSegment.dataset_id == self.dataset.id,
+            DocumentSegment.document_id == document_id
+        ).all()
+
+        ids = [segment.id for segment in segments]
+
+        keyword_table = self._get_dataset_keyword_table()
+        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+
+        self._save_dataset_keyword_table(keyword_table)
+
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        return KeywordTableRetriever(index=self, **kwargs)
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        keyword_table = self._get_dataset_keyword_table()
+
+        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
+        k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+
+        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
+
+        documents = []
+        for chunk_index in sorted_chunk_indices:
+            segment = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.index_node_id == chunk_index
+            ).first()
+
+            if segment:
+                documents.append(Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": chunk_index,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                ))
+
+        return documents
+
+    def delete(self) -> None:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            db.session.delete(dataset_keyword_table)
+            db.session.commit()
+
+    def _save_dataset_keyword_table(self, keyword_table):
+        keyword_table_dict = {
+            '__type__': 'keyword_table',
+            '__data__': {
+                "index_id": self.dataset.id,
+                "summary": None,
+                "table": keyword_table
+            }
+        }
+        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+        db.session.commit()
+
+    def _get_dataset_keyword_table(self) -> Optional[dict]:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            if dataset_keyword_table.keyword_table_dict:
+                return dataset_keyword_table.keyword_table_dict['__data__']['table']
+        else:
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table=json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            )
+            db.session.add(dataset_keyword_table)
+            db.session.commit()
+
+        return {}
+
+    def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
+        for keyword in keywords:
+            if keyword not in keyword_table:
+                keyword_table[keyword] = set()
+            keyword_table[keyword].add(id)
+        return keyword_table
+
+    def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
+        # get set of ids that correspond to node
+        node_idxs_to_delete = set(ids)
+
+        # delete node_idxs from keyword to node idxs mapping
+        keywords_to_delete = set()
+        for keyword, node_idxs in keyword_table.items():
+            if node_idxs_to_delete.intersection(node_idxs):
+                keyword_table[keyword] = node_idxs.difference(
+                    node_idxs_to_delete
+                )
+                if not keyword_table[keyword]:
+                    keywords_to_delete.add(keyword)
+
+        for keyword in keywords_to_delete:
+            del keyword_table[keyword]
+
+        return keyword_table
+
+    def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keywords = keyword_table_handler.extract_keywords(query)
+
+        # go through text chunks in order of most matching keywords
+        chunk_indices_count: Dict[str, int] = defaultdict(int)
+        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+        for keyword in keywords:
+            for node_id in keyword_table[keyword]:
+                chunk_indices_count[node_id] += 1
+
+        sorted_chunk_indices = sorted(
+            list(chunk_indices_count.keys()),
+            key=lambda x: chunk_indices_count[x],
+            reverse=True,
+        )
+
+        return sorted_chunk_indices[: k]
+
+    def _update_segment_keywords(self, node_id: str, keywords: List[str]):
+        document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
+        if document_segment:
+            document_segment.keywords = keywords
+            db.session.commit()
+
+
+class KeywordTableRetriever(BaseRetriever, BaseModel):
+    index: KeywordTableIndex
+    search_kwargs: dict = Field(default_factory=dict)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
+        return self.index.search(query, **self.search_kwargs)
+
+    async def aget_relevant_documents(self, query: str) -> List[Document]:
+        raise NotImplementedError("KeywordTableRetriever does not support async")
+
+
+class SetEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, set):
+            return list(obj)
+        return super().default(obj)

+ 0 - 0
api/core/index/keyword_table/stopwords.py → api/core/index/keyword_table_index/stopwords.py


+ 0 - 79
api/core/index/query/synthesizer.py

@@ -1,79 +0,0 @@
-from typing import (
-    Any,
-    Dict,
-    Optional, Sequence,
-)
-
-from llama_index.indices.response.response_synthesis import ResponseSynthesizer
-from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-from llama_index.types import RESPONSE_TEXT_TYPE
-
-
-class EnhanceResponseSynthesizer(ResponseSynthesizer):
-    @classmethod
-    def from_args(
-            cls,
-            service_context: ServiceContext,
-            streaming: bool = False,
-            use_async: bool = False,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            response_kwargs: Optional[Dict] = None,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-    ) -> "ResponseSynthesizer":
-        response_builder: Optional[BaseResponseBuilder] = None
-        if response_mode != ResponseMode.NO_TEXT:
-            if response_mode == 'no_synthesizer':
-                response_builder = NoSynthesizer(
-                    service_context=service_context,
-                    simple_template=simple_template,
-                    streaming=streaming,
-                )
-            else:
-                response_builder = get_response_builder(
-                    service_context,
-                    text_qa_template,
-                    refine_template,
-                    simple_template,
-                    response_mode,
-                    use_async=use_async,
-                    streaming=streaming,
-                )
-        return cls(response_builder, response_mode, response_kwargs, optimizer)
-
-
-class NoSynthesizer(BaseResponseBuilder):
-    def __init__(
-            self,
-            service_context: ServiceContext,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            streaming: bool = False,
-    ) -> None:
-        super().__init__(service_context, streaming)
-
-    async def aget_response(
-            self,
-            query_str: str,
-            text_chunks: Sequence[str],
-            prev_response: Optional[str] = None,
-            **response_kwargs: Any,
-    ) -> RESPONSE_TEXT_TYPE:
-        return "\n".join(text_chunks)
-
-    def get_response(
-            self,
-            query_str: str,
-            text_chunks: Sequence[str],
-            prev_response: Optional[str] = None,
-            **response_kwargs: Any,
-    ) -> RESPONSE_TEXT_TYPE:
-        return "\n".join(text_chunks)

+ 0 - 22
api/core/index/readers/html_parser.py

@@ -1,22 +0,0 @@
-from pathlib import Path
-from typing import Dict
-
-from bs4 import BeautifulSoup
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class HTMLParser(BaseParser):
-    """HTML parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser."""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        """Parse file."""
-        with open(file, "rb") as fp:
-            soup = BeautifulSoup(fp, 'html.parser')
-            text = soup.get_text()
-            text = text.strip() if text else ''
-
-        return text

+ 0 - 111
api/core/index/readers/markdown_parser.py

@@ -1,111 +0,0 @@
-"""Markdown parser.
-
-Contains parser for md files.
-
-"""
-import re
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
-
-from llama_index.readers.file.base_parser import BaseParser
-
-
-class MarkdownParser(BaseParser):
-    """Markdown parser.
-
-    Extract text from markdown files.
-    Returns dictionary with keys as headers and values as the text between headers.
-
-    """
-
-    def __init__(
-        self,
-        *args: Any,
-        remove_hyperlinks: bool = True,
-        remove_images: bool = True,
-        **kwargs: Any,
-    ) -> None:
-        """Init params."""
-        super().__init__(*args, **kwargs)
-        self._remove_hyperlinks = remove_hyperlinks
-        self._remove_images = remove_images
-
-    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
-        """Convert a markdown file to a dictionary.
-
-        The keys are the headers and the values are the text under each header.
-
-        """
-        markdown_tups: List[Tuple[Optional[str], str]] = []
-        lines = markdown_text.split("\n")
-
-        current_header = None
-        current_text = ""
-
-        for line in lines:
-            header_match = re.match(r"^#+\s", line)
-            if header_match:
-                if current_header is not None:
-                    markdown_tups.append((current_header, current_text))
-
-                current_header = line
-                current_text = ""
-            else:
-                current_text += line + "\n"
-        markdown_tups.append((current_header, current_text))
-
-        if current_header is not None:
-            # pass linting, assert keys are defined
-            markdown_tups = [
-                (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
-                for key, value in markdown_tups
-            ]
-        else:
-            markdown_tups = [
-                (key, re.sub("\n", "", value)) for key, value in markdown_tups
-            ]
-
-        return markdown_tups
-
-    def remove_images(self, content: str) -> str:
-        """Get a dictionary of a markdown file from its path."""
-        pattern = r"!{1}\[\[(.*)\]\]"
-        content = re.sub(pattern, "", content)
-        return content
-
-    def remove_hyperlinks(self, content: str) -> str:
-        """Get a dictionary of a markdown file from its path."""
-        pattern = r"\[(.*?)\]\((.*?)\)"
-        content = re.sub(pattern, r"\1", content)
-        return content
-
-    def _init_parser(self) -> Dict:
-        """Initialize the parser with the config."""
-        return {}
-
-    def parse_tups(
-        self, filepath: Path, errors: str = "ignore"
-    ) -> List[Tuple[Optional[str], str]]:
-        """Parse file into tuples."""
-        with open(filepath, "r", encoding="utf-8") as f:
-            content = f.read()
-        if self._remove_hyperlinks:
-            content = self.remove_hyperlinks(content)
-        if self._remove_images:
-            content = self.remove_images(content)
-        markdown_tups = self.markdown_to_tups(content)
-        return markdown_tups
-
-    def parse_file(
-        self, filepath: Path, errors: str = "ignore"
-    ) -> Union[str, List[str]]:
-        """Parse file into string."""
-        tups = self.parse_tups(filepath, errors=errors)
-        results = []
-        # TODO: don't include headers right now
-        for header, value in tups:
-            if header is None:
-                results.append(value)
-            else:
-                results.append(f"\n\n{header}\n{value}")
-        return results

+ 0 - 56
api/core/index/readers/pdf_parser.py

@@ -1,56 +0,0 @@
-from pathlib import Path
-from typing import Dict
-
-from flask import current_app
-from llama_index.readers.file.base_parser import BaseParser
-from pypdf import PdfReader
-
-from extensions.ext_storage import storage
-from models.model import UploadFile
-
-
-class PDFParser(BaseParser):
-    """PDF parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser."""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        """Parse file."""
-        if not current_app.config.get('PDF_PREVIEW', True):
-            return ''
-
-        plaintext_file_key = ''
-        plaintext_file_exists = False
-        if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
-            upload_file: UploadFile = self._parser_config['upload_file']
-            if upload_file.hash:
-                plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
-                try:
-                    text = storage.load(plaintext_file_key).decode('utf-8')
-                    plaintext_file_exists = True
-                    return text
-                except FileNotFoundError:
-                    pass
-
-        text_list = []
-        with open(file, "rb") as fp:
-            # Create a PDF object
-            pdf = PdfReader(fp)
-
-            # Get the number of pages in the PDF document
-            num_pages = len(pdf.pages)
-
-            # Iterate over every page
-            for page in range(num_pages):
-                # Extract the text from the page
-                page_text = pdf.pages[page].extract_text()
-                text_list.append(page_text)
-        text = "\n".join(text_list)
-
-        # save plaintext file for caching
-        if not plaintext_file_exists and plaintext_file_key:
-            storage.save(plaintext_file_key, text.encode('utf-8'))
-
-        return text

+ 0 - 33
api/core/index/readers/xlsx_parser.py

@@ -1,33 +0,0 @@
-from pathlib import Path
-import json
-from typing import Dict
-from openpyxl import load_workbook
-
-from llama_index.readers.file.base_parser import BaseParser
-from flask import current_app
-
-
-class XLSXParser(BaseParser):
-    """XLSX parser."""
-
-    def _init_parser(self) -> Dict:
-        """Init parser"""
-        return {}
-
-    def parse_file(self, file: Path, errors: str = "ignore") -> str:
-        data = []
-        keys = []
-        with open(file, "r") as fp:
-            wb = load_workbook(filename=file, read_only=True)
-            # loop over all sheets
-            for sheet in wb:
-                for row in sheet.iter_rows(values_only=True):
-                    if all(v is None for v in row):
-                        continue
-                    if keys == []:
-                        keys = list(map(str, row))
-                    else:
-                        row_dict = dict(zip(keys, row))
-                        row_dict = {k: v for k, v in row_dict.items() if v}
-                        data.append(json.dumps(row_dict, ensure_ascii=False))
-        return '\n\n'.join(data)

+ 0 - 136
api/core/index/vector_index.py

@@ -1,136 +0,0 @@
-import json
-import logging
-from typing import List, Optional
-
-from llama_index.data_structs import Node
-from requests import ReadTimeout
-from sqlalchemy.exc import IntegrityError
-from tenacity import retry, stop_after_attempt, retry_if_exception_type
-
-from core.index.index_builder import IndexBuilder
-from core.vector_store.base import BaseGPTVectorStoreIndex
-from extensions.ext_vector_store import vector_store
-from extensions.ext_database import db
-from models.dataset import Dataset, Embedding
-
-
-class VectorIndex:
-
-    def __init__(self, dataset: Dataset):
-        self._dataset = dataset
-
-    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
-        if not self._dataset.index_struct_dict:
-            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
-            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
-            db.session.commit()
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        if duplicate_check:
-            nodes = self._filter_duplicate_nodes(index, nodes)
-
-        embedding_queue_nodes = []
-        embedded_nodes = []
-        for node in nodes:
-            node_hash = node.doc_hash
-
-            # if node hash in cached embedding tables, use cached embedding
-            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
-            if embedding:
-                node.embedding = embedding.get_embedding()
-                embedded_nodes.append(node)
-            else:
-                embedding_queue_nodes.append(node)
-
-        if embedding_queue_nodes:
-            embedding_results = index._get_node_embedding_results(
-                embedding_queue_nodes,
-                set(),
-            )
-
-            # pre embed nodes for cached embedding
-            for embedding_result in embedding_results:
-                node = embedding_result.node
-                node.embedding = embedding_result.embedding
-
-                try:
-                    embedding = Embedding(hash=node.doc_hash)
-                    embedding.set_embedding(node.embedding)
-                    db.session.add(embedding)
-                    db.session.commit()
-                except IntegrityError:
-                    db.session.rollback()
-                    continue
-                except:
-                    logging.exception('Failed to add embedding to db')
-                    continue
-
-                embedded_nodes.append(node)
-
-        self.index_insert_nodes(index, embedded_nodes)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
-        index.insert_nodes(nodes)
-
-    def del_nodes(self, node_ids: List[str]):
-        if not self._dataset.index_struct_dict:
-            return
-
-        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        for node_id in node_ids:
-            self.index_delete_node(index, node_id)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
-        index.delete_node(node_id)
-
-    def del_doc(self, doc_id: str):
-        if not self._dataset.index_struct_dict:
-            return
-
-        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-        index = vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-        self.index_delete_doc(index, doc_id)
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
-        index.delete(doc_id)
-
-    @property
-    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
-        if not self._dataset.index_struct_dict:
-            return None
-
-        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-        return vector_store.get_index(
-            service_context=service_context,
-            index_struct=self._dataset.index_struct_dict
-        )
-
-    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
-        for node in nodes:
-            node_id = node.doc_id
-            exists_duplicate_node = index.exists_by_node_id(node_id)
-            if exists_duplicate_node:
-                nodes.remove(node)
-
-        return nodes

+ 175 - 0
api/core/index/vector_index/base.py

@@ -0,0 +1,175 @@
+import json
+import logging
+from abc import abstractmethod
+from typing import List, Any, cast
+
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from weaviate import UnexpectedStatusCodeException
+
+from core.index.base import BaseIndex
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+
+
+class BaseVectorIndex(BaseIndex):
+    
+    def __init__(self, dataset: Dataset, embeddings: Embeddings):
+        super().__init__(dataset)
+        self._embeddings = embeddings
+        self._vector_store = None
+        
+    def get_type(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_index_name(self, dataset: Dataset) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def to_index_struct(self) -> dict:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_vector_store(self) -> VectorStore:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_vector_store_class(self) -> type:
+        raise NotImplementedError
+
+    def search(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
+        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
+
+        if search_type == 'similarity_score_threshold':
+            score_threshold = search_kwargs.get("score_threshold")
+            if (score_threshold is None) or (not isinstance(score_threshold, float)):
+                search_kwargs['score_threshold'] = .0
+
+            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
+                query, **search_kwargs
+            )
+
+            docs = []
+            for doc, similarity in docs_with_similarity:
+                doc.metadata['score'] = similarity
+                docs.append(doc)
+
+            return docs
+
+        # similarity k
+        # mmr k, fetch_k, lambda_mult
+        # similarity_score_threshold k
+        return vector_store.as_retriever(
+            search_type=search_type,
+            search_kwargs=search_kwargs
+        ).get_relevant_documents(query)
+
+    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        return vector_store.as_retriever(**kwargs)
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        if kwargs.get('duplicate_check', False):
+            texts = self._filter_duplicate_texts(texts)
+
+        uuids = self._get_uuids(texts)
+        vector_store.add_documents(texts, uuids=uuids)
+
+    def text_exists(self, id: str) -> bool:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        return vector_store.text_exists(id)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        for node_id in ids:
+            vector_store.del_text(node_id)
+
+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
+    def _is_origin(self):
+        return False
+
+    def recreate_dataset(self, dataset: Dataset):
+        logging.info(f"Recreating dataset {dataset.id}")
+
+        try:
+            self.delete()
+        except UnexpectedStatusCodeException as e:
+            if e.status_code != 400:
+                # 400 means index not exists
+                raise e
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+            
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        origin_index_struct = self.dataset.index_struct
+        self.dataset.index_struct = None
+
+        if documents:
+            try:
+                self.create(documents)
+            except Exception as e:
+                self.dataset.index_struct = origin_index_struct
+                raise e
+
+            dataset.index_struct = json.dumps(self.to_index_struct())
+
+        db.session.commit()
+
+        self.dataset = dataset
+        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 116 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -0,0 +1,116 @@
+import os
+from typing import Optional, Any, List, cast
+
+import qdrant_client
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from pydantic import BaseModel
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.qdrant_vector_store import QdrantVectorStore
+from models.dataset import Dataset
+
+
+class QdrantConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    root_path: Optional[str]
+    
+    def to_qdrant_params(self):
+        if self.endpoint and self.endpoint.startswith('path:'):
+            path = self.endpoint.replace('path:', '')
+            if not os.path.isabs(path):
+                path = os.path.join(self.root_path, path)
+
+            return {
+                'path': path
+            }
+        else:
+            return {
+                'url': self.endpoint,
+                'api_key': self.api_key,
+            }
+
+
+class QdrantVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client_config = config
+
+    def get_type(self) -> str:
+        return 'qdrant'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            return self.dataset.index_struct_dict['vector_store']['collection_name']
+
+        dataset_id = dataset.id
+        return "Index_" + dataset_id.replace("-", "_")
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = QdrantVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            collection_name=self.get_index_name(self.dataset),
+            ids=uuids,
+            content_payload_key='text',
+            **self._client_config.to_qdrant_params()
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+        
+        client = qdrant_client.QdrantClient(
+            **self._client_config.to_qdrant_params()
+        )
+
+        return QdrantVectorStore(
+            client=client,
+            collection_name=self.get_index_name(self.dataset),
+            embeddings=self._embeddings,
+            content_payload_key='text'
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return QdrantVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="metadata.document_id",
+                    match=models.MatchValue(value=document_id),
+                ),
+            ],
+        ))
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
+            if class_prefix.startswith('Vector_'):
+                # original class_prefix
+                return True
+
+        return False

+ 69 - 0
api/core/index/vector_index/vector_index.py

@@ -0,0 +1,69 @@
+import json
+
+from flask import current_app
+from langchain.embeddings.base import Embeddings
+
+from core.index.vector_index.base import BaseVectorIndex
+from extensions.ext_database import db
+from models.dataset import Dataset, Document
+
+
+class VectorIndex:
+    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
+        self._dataset = dataset
+        self._embeddings = embeddings
+        self._vector_index = self._init_vector_index(dataset, config, embeddings)
+
+    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
+        vector_type = config.get('VECTOR_STORE')
+
+        if self._dataset.index_struct_dict:
+            vector_type = self._dataset.index_struct_dict['type']
+
+        if not vector_type:
+            raise ValueError(f"Vector store must be specified.")
+
+        if vector_type == "weaviate":
+            from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
+
+            return WeaviateVectorIndex(
+                dataset=dataset,
+                config=WeaviateConfig(
+                    endpoint=config.get('WEAVIATE_ENDPOINT'),
+                    api_key=config.get('WEAVIATE_API_KEY'),
+                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
+                ),
+                embeddings=embeddings
+            )
+        elif vector_type == "qdrant":
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+            return QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=config.get('QDRANT_URL'),
+                    api_key=config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+        else:
+            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+    def add_texts(self, texts: list[Document], **kwargs):
+        if not self._dataset.index_struct_dict:
+            self._vector_index.create(texts, **kwargs)
+            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
+            db.session.commit()
+            return
+
+        self._vector_index.add_texts(texts, **kwargs)
+
+    def __getattr__(self, name):
+        if self._vector_index is not None:
+            method = getattr(self._vector_index, name)
+            if callable(method):
+                return method
+
+        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
+

+ 132 - 0
api/core/index/vector_index/weaviate_vector_index.py

@@ -0,0 +1,132 @@
+from typing import Optional, cast
+
+import weaviate
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore
+from pydantic import BaseModel, root_validator
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
+from models.dataset import Dataset
+
+
+class WeaviateConfig(BaseModel):
+    endpoint: str
+    api_key: Optional[str]
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config WEAVIATE_ENDPOINT is required")
+        return values
+
+
+class WeaviateVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client = self._init_client(config)
+
+    def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
+        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+
+        weaviate.connect.connection.has_grpc = False
+
+        client = weaviate.Client(
+            url=config.endpoint,
+            auth_client_secret=auth_config,
+            timeout_config=(5, 60),
+            startup_period=None
+        )
+
+        client.batch.configure(
+            # `batch_size` takes an `int` value to enable auto-batching
+            # (`None` is used for manual batching)
+            batch_size=config.batch_size,
+            # dynamically update the `batch_size` based on import speed
+            dynamic=True,
+            # `timeout_retries` takes an `int` value to retry on time outs
+            timeout_retries=3,
+        )
+
+        return client
+
+    def get_type(self) -> str:
+        return 'weaviate'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
+
+        return WeaviateVectorStore(
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            text_key='text',
+            embedding=self._embeddings,
+            attributes=attributes,
+            by_text=False
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return WeaviateVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": ["document_id"],
+            "valueText": document_id
+        })
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                return True
+
+        return False

+ 262 - 283
api/core/indexing_runner.py

@@ -1,35 +1,34 @@
 import datetime
 import json
+import logging
 import re
-import tempfile
 import time
-from pathlib import Path
-from typing import Optional, List
+import uuid
+from typing import Optional, List, cast
 
+from flask import current_app
 from flask_login import current_user
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
-from llama_index import SimpleDirectoryReader
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.node_parser import SimpleNodeParser, NodeParser
-from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
-from llama_index.readers.file.markdown_parser import MarkdownParser
-
-from core.data_source.notion import NotionPageReader
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
+from core.data_loader.loader.notion import NotionLoader
 from core.docstore.dataset_docstore import DatesetDocumentStore
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.markdown_parser import MarkdownParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.index import IndexBuilder
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.llm_builder import LLMBuilder
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
 from core.llm.token_calculator import TokenCalculator
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
-from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
+from libs import helper
+from models.dataset import Document as DatasetDocument
+from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
 from models.model import UploadFile
 from models.source import DataSourceBinding
 
@@ -40,135 +39,171 @@ class IndexingRunner:
         self.storage = storage
         self.embedding_model_name = embedding_model_name
 
-    def run(self, documents: List[Document]):
+    def run(self, dataset_documents: List[DatasetDocument]):
         """Run the indexing process."""
-        for document in documents:
+        for dataset_document in dataset_documents:
+            try:
+                # get dataset
+                dataset = Dataset.query.filter_by(
+                    id=dataset_document.dataset_id
+                ).first()
+
+                if not dataset:
+                    raise ValueError("no dataset found")
+
+                # load file
+                text_docs = self._load_data(dataset_document)
+
+                # get the process rule
+                processing_rule = db.session.query(DatasetProcessRule). \
+                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
+                    first()
+
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
+
+                # split to documents
+                documents = self._step_split(
+                    text_docs=text_docs,
+                    splitter=splitter,
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    processing_rule=processing_rule
+                )
+
+                # build index
+                self._build_index(
+                    dataset=dataset,
+                    dataset_document=dataset_document,
+                    documents=documents
+                )
+            except DocumentIsPausedException:
+                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+            except ProviderTokenNotInitError as e:
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e.description)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+            except Exception as e:
+                logging.exception("consume document failed")
+                dataset_document.indexing_status = 'error'
+                dataset_document.error = str(e)
+                dataset_document.stopped_at = datetime.datetime.utcnow()
+                db.session.commit()
+
+    def run_in_splitting_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is splitting."""
+        try:
             # get dataset
             dataset = Dataset.query.filter_by(
-                id=document.dataset_id
+                id=dataset_document.dataset_id
             ).first()
 
             if not dataset:
                 raise ValueError("no dataset found")
 
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            db.session.delete(document_segments)
+            db.session.commit()
+
             # load file
-            text_docs = self._load_data(document)
+            text_docs = self._load_data(dataset_document)
 
             # get the process rule
             processing_rule = db.session.query(DatasetProcessRule). \
-                filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
+                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                 first()
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._step_split(
+            # split to documents
+            documents = self._step_split(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 dataset=dataset,
-                document=document,
+                dataset_document=dataset_document,
                 processing_rule=processing_rule
             )
 
             # build index
             self._build_index(
                 dataset=dataset,
-                document=document,
-                nodes=nodes
+                dataset_document=dataset_document,
+                documents=documents
             )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
 
-    def run_in_splitting_status(self, document: Document):
-        """Run the indexing process when the index_status is splitting."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        db.session.delete(document_segments)
-        db.session.commit()
-        # load file
-        text_docs = self._load_data(document)
-
-        # get the process rule
-        processing_rule = db.session.query(DatasetProcessRule). \
-            filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
-            first()
-
-        # get node parser for splitting
-        node_parser = self._get_node_parser(processing_rule)
+    def run_in_indexing_status(self, dataset_document: DatasetDocument):
+        """Run the indexing process when the index_status is indexing."""
+        try:
+            # get dataset
+            dataset = Dataset.query.filter_by(
+                id=dataset_document.dataset_id
+            ).first()
 
-        # split to nodes
-        nodes = self._step_split(
-            text_docs=text_docs,
-            node_parser=node_parser,
-            dataset=dataset,
-            document=document,
-            processing_rule=processing_rule
-        )
+            if not dataset:
+                raise ValueError("no dataset found")
 
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # get exist document_segment list and delete
+            document_segments = DocumentSegment.query.filter_by(
+                dataset_id=dataset.id,
+                document_id=dataset_document.id
+            ).all()
+
+            documents = []
+            if document_segments:
+                for document_segment in document_segments:
+                    # transform segment to node
+                    if document_segment.status != "completed":
+                        document = Document(
+                            page_content=document_segment.content,
+                            metadata={
+                                "doc_id": document_segment.index_node_id,
+                                "doc_hash": document_segment.index_node_hash,
+                                "document_id": document_segment.document_id,
+                                "dataset_id": document_segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
 
-    def run_in_indexing_status(self, document: Document):
-        """Run the indexing process when the index_status is indexing."""
-        # get dataset
-        dataset = Dataset.query.filter_by(
-            id=document.dataset_id
-        ).first()
-
-        if not dataset:
-            raise ValueError("no dataset found")
-
-        # get exist document_segment list and delete
-        document_segments = DocumentSegment.query.filter_by(
-            dataset_id=dataset.id,
-            document_id=document.id
-        ).all()
-        nodes = []
-        if document_segments:
-            for document_segment in document_segments:
-                # transform segment to node
-                if document_segment.status != "completed":
-                    relationships = {
-                        DocumentRelationship.SOURCE: document_segment.document_id,
-                    }
-
-                    previous_segment = document_segment.previous_segment
-                    if previous_segment:
-                        relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-                    next_segment = document_segment.next_segment
-                    if next_segment:
-                        relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-                    node = Node(
-                        doc_id=document_segment.index_node_id,
-                        doc_hash=document_segment.index_node_hash,
-                        text=document_segment.content,
-                        extra_info=None,
-                        node_info=None,
-                        relationships=relationships
-                    )
-                    nodes.append(node)
-
-        # build index
-        self._build_index(
-            dataset=dataset,
-            document=document,
-            nodes=nodes
-        )
+            # build index
+            self._build_index(
+                dataset=dataset,
+                dataset_document=dataset_document,
+                documents=documents
+            )
+        except DocumentIsPausedException:
+            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
+        except ProviderTokenNotInitError as e:
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e.description)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
+        except Exception as e:
+            logging.exception("consume document failed")
+            dataset_document.indexing_status = 'error'
+            dataset_document.error = str(e)
+            dataset_document.stopped_at = datetime.datetime.utcnow()
+            db.session.commit()
 
     def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
         """
@@ -179,28 +214,28 @@ class IndexingRunner:
         total_segments = 0
         for file_detail in file_details:
             # load data from file
-            text_docs = self._load_data_from_file(file_detail)
+            text_docs = FileExtractor.load(file_detail)
 
             processing_rule = DatasetProcessRule(
                 mode=tmp_processing_rule["mode"],
                 rules=json.dumps(tmp_processing_rule["rules"])
             )
 
-            # get node parser for splitting
-            node_parser = self._get_node_parser(processing_rule)
+            # get splitter
+            splitter = self._get_splitter(processing_rule)
 
-            # split to nodes
-            nodes = self._split_to_nodes(
+            # split to documents
+            documents = self._split_to_documents(
                 text_docs=text_docs,
-                node_parser=node_parser,
+                splitter=splitter,
                 processing_rule=processing_rule
             )
-            total_segments += len(nodes)
-            for node in nodes:
+            total_segments += len(documents)
+            for document in documents:
                 if len(preview_texts) < 5:
-                    preview_texts.append(node.get_text())
+                    preview_texts.append(document.page_content)
 
-                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -230,35 +265,36 @@ class IndexingRunner:
             ).first()
             if not data_source_binding:
                 raise ValueError('Data source binding not found.')
-            reader = NotionPageReader(integration_token=data_source_binding.access_token)
+
             for page in notion_info['pages']:
-                if page['type'] == 'page':
-                    page_ids = [page['page_id']]
-                    documents = reader.load_data_as_documents(page_ids=page_ids)
-                elif page['type'] == 'database':
-                    documents = reader.load_data_as_documents(database_id=page['page_id'])
-                else:
-                    documents = []
+                loader = NotionLoader(
+                    notion_access_token=data_source_binding.access_token,
+                    notion_workspace_id=workspace_id,
+                    notion_obj_id=page['page_id'],
+                    notion_page_type=page['type']
+                )
+                documents = loader.load()
+
                 processing_rule = DatasetProcessRule(
                     mode=tmp_processing_rule["mode"],
                     rules=json.dumps(tmp_processing_rule["rules"])
                 )
 
-                # get node parser for splitting
-                node_parser = self._get_node_parser(processing_rule)
+                # get splitter
+                splitter = self._get_splitter(processing_rule)
 
-                # split to nodes
-                nodes = self._split_to_nodes(
+                # split to documents
+                documents = self._split_to_documents(
                     text_docs=documents,
-                    node_parser=node_parser,
+                    splitter=splitter,
                     processing_rule=processing_rule
                 )
-                total_segments += len(nodes)
-                for node in nodes:
+                total_segments += len(documents)
+                for document in documents:
                     if len(preview_texts) < 5:
-                        preview_texts.append(node.get_text())
+                        preview_texts.append(document.page_content)
 
-                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
+                    tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
 
         return {
             "total_segments": total_segments,
@@ -268,14 +304,14 @@ class IndexingRunner:
             "preview": preview_texts
         }
 
-    def _load_data(self, document: Document) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
         # load file
-        if document.data_source_type not in ["upload_file", "notion_import"]:
+        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
 
-        data_source_info = document.data_source_info_dict
+        data_source_info = dataset_document.data_source_info_dict
         text_docs = []
-        if document.data_source_type == 'upload_file':
+        if dataset_document.data_source_type == 'upload_file':
             if not data_source_info or 'upload_file_id' not in data_source_info:
                 raise ValueError("no upload file found")
 
@@ -283,47 +319,28 @@ class IndexingRunner:
                 filter(UploadFile.id == data_source_info['upload_file_id']). \
                 one_or_none()
 
-            text_docs = self._load_data_from_file(file_detail)
-        elif document.data_source_type == 'notion_import':
-            if not data_source_info or 'notion_page_id' not in data_source_info \
-                    or 'notion_workspace_id' not in data_source_info:
-                raise ValueError("no notion page found")
-            workspace_id = data_source_info['notion_workspace_id']
-            page_id = data_source_info['notion_page_id']
-            page_type = data_source_info['type']
-            data_source_binding = DataSourceBinding.query.filter(
-                db.and_(
-                    DataSourceBinding.tenant_id == document.tenant_id,
-                    DataSourceBinding.provider == 'notion',
-                    DataSourceBinding.disabled == False,
-                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
-                )
-            ).first()
-            if not data_source_binding:
-                raise ValueError('Data source binding not found.')
-            if page_type == 'page':
-                # add page last_edited_time to data_source_info
-                self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
-            elif page_type == 'database':
-                # add page last_edited_time to data_source_info
-                self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
-                text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
+            text_docs = FileExtractor.load(file_detail)
+        elif dataset_document.data_source_type == 'notion_import':
+            loader = NotionLoader.from_document(dataset_document)
+            text_docs = loader.load()
+
         # update document status to splitting
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="splitting",
             extra_update_params={
-                Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
-                Document.parsing_completed_at: datetime.datetime.utcnow()
+                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
+                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
             }
         )
 
         # replace doc id to document model id
+        text_docs = cast(List[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
-            text_doc.text = self.filter_string(text_doc.get_text())
-            text_doc.doc_id = document.id
+            text_doc.page_content = self.filter_string(text_doc.page_content)
+            text_doc.metadata['document_id'] = dataset_document.id
+            text_doc.metadata['dataset_id'] = dataset_document.dataset_id
 
         return text_docs
 
@@ -331,61 +348,7 @@ class IndexingRunner:
         pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
         return pattern.sub('', text)
 
-    def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            self.storage.download(upload_file.key, filepath)
-
-            file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
-            file_extractor[".markdown"] = MarkdownParser()
-            file_extractor[".md"] = MarkdownParser()
-            file_extractor[".html"] = HTMLParser()
-            file_extractor[".htm"] = HTMLParser()
-            file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
-            file_extractor[".xlsx"] = XLSXParser()
-
-            loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
-            text_docs = loader.load_data()
-
-            return text_docs
-
-    def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
-        page_ids = [page_id]
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(page_ids=page_ids)
-        return text_docs
-
-    def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
-        reader = NotionPageReader(integration_token=access_token)
-        text_docs = reader.load_data_as_documents(database_id=database_id)
-        return text_docs
-
-    def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
-        reader = NotionPageReader(integration_token=access_token)
-        last_edited_time = reader.get_database_last_edited_time(page_id)
-        data_source_info = document.data_source_info_dict
-        data_source_info['last_edited_time'] = last_edited_time
-        update_params = {
-            Document.data_source_info: json.dumps(data_source_info)
-        }
-
-        Document.query.filter_by(id=document.id).update(update_params)
-        db.session.commit()
-
-    def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
+    def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
@@ -414,68 +377,83 @@ class IndexingRunner:
                 separators=["\n\n", "。", ".", " ", ""]
             )
 
-        return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
+        return character_splitter
 
-    def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
-                    dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
+    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> List[Document]:
         """
-        Split the text documents into nodes and save them to the document segment.
+        Split the text documents into documents and save them to the document segment.
         """
-        nodes = self._split_to_nodes(
+        documents = self._split_to_documents(
             text_docs=text_docs,
-            node_parser=node_parser,
+            splitter=splitter,
             processing_rule=processing_rule
         )
 
         # save node to document segment
         doc_store = DatesetDocumentStore(
             dataset=dataset,
-            user_id=document.created_by,
+            user_id=dataset_document.created_by,
             embedding_model_name=self.embedding_model_name,
-            document_id=document.id
+            document_id=dataset_document.id
         )
+
         # add document segments
-        doc_store.add_documents(nodes)
+        doc_store.add_documents(documents)
 
         # update document status to indexing
         cur_time = datetime.datetime.utcnow()
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="indexing",
             extra_update_params={
-                Document.cleaning_completed_at: cur_time,
-                Document.splitting_completed_at: cur_time,
+                DatasetDocument.cleaning_completed_at: cur_time,
+                DatasetDocument.splitting_completed_at: cur_time,
             }
         )
 
         # update segment status to indexing
         self._update_segments_by_document(
-            document_id=document.id,
+            dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
                 DocumentSegment.indexing_at: datetime.datetime.utcnow()
             }
         )
 
-        return nodes
+        return documents
 
-    def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
-                        processing_rule: DatasetProcessRule) -> List[Node]:
+    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
+                        processing_rule: DatasetProcessRule) -> List[Document]:
         """
         Split the text documents into nodes.
         """
-        all_nodes = []
+        all_documents = []
         for text_doc in text_docs:
             # document clean
-            document_text = self._document_clean(text_doc.get_text(), processing_rule)
-            text_doc.text = document_text
+            document_text = self._document_clean(text_doc.page_content, processing_rule)
+            text_doc.page_content = document_text
 
             # parse document to nodes
-            nodes = node_parser.get_nodes_from_documents([text_doc])
-            nodes = [node for node in nodes if node.text is not None and node.text.strip()]
-            all_nodes.extend(nodes)
+            documents = splitter.split_documents([text_doc])
+
+            split_documents = []
+            for document in documents:
+                if document.page_content is None or not document.page_content.strip():
+                    continue
+
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document.page_content)
+
+                document.metadata['doc_id'] = doc_id
+                document.metadata['doc_hash'] = hash
+
+                split_documents.append(document)
+
+            all_documents.extend(split_documents)
 
-        return all_nodes
+        return all_documents
 
     def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
         """
@@ -506,37 +484,38 @@ class IndexingRunner:
 
         return text
 
-    def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
         """
         Build the index for the document.
         """
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
 
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
         chunk_size = 100
-        for i in range(0, len(nodes), chunk_size):
+        for i in range(0, len(documents), chunk_size):
             # check document is paused
-            self._check_document_paused_status(document.id)
-            chunk_nodes = nodes[i:i + chunk_size]
+            self._check_document_paused_status(dataset_document.id)
+            chunk_documents = documents[i:i + chunk_size]
 
             tokens += sum(
-                TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
+                TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
+                for document in chunk_documents
             )
 
             # save vector index
-            if dataset.indexing_technique == "high_quality":
-                vector_index.add_nodes(chunk_nodes)
+            if vector_index:
+                vector_index.add_texts(chunk_documents)
 
             # save keyword index
-            keyword_table_index.add_nodes(chunk_nodes)
+            keyword_table_index.add_texts(chunk_documents)
 
-            node_ids = [node.doc_id for node in chunk_nodes]
+            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.document_id == document.id,
-                DocumentSegment.index_node_id.in_(node_ids),
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.index_node_id.in_(document_ids),
                 DocumentSegment.status == "indexing"
             ).update({
                 DocumentSegment.status: "completed",
@@ -549,12 +528,12 @@ class IndexingRunner:
 
         # update document status to completed
         self._update_document_index_status(
-            document_id=document.id,
+            document_id=dataset_document.id,
             after_indexing_status="completed",
             extra_update_params={
-                Document.tokens: tokens,
-                Document.completed_at: datetime.datetime.utcnow(),
-                Document.indexing_latency: indexing_end_at - indexing_start_at,
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
             }
         )
 
@@ -569,25 +548,25 @@ class IndexingRunner:
         """
         Update the document indexing status.
         """
-        count = Document.query.filter_by(id=document_id, is_paused=True).count()
+        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
         if count > 0:
             raise DocumentIsPausedException()
 
         update_params = {
-            Document.indexing_status: after_indexing_status
+            DatasetDocument.indexing_status: after_indexing_status
         }
 
         if extra_update_params:
             update_params.update(extra_update_params)
 
-        Document.query.filter_by(id=document_id).update(update_params)
+        DatasetDocument.query.filter_by(id=document_id).update(update_params)
         db.session.commit()
 
-    def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
+    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
         """
         Update the document segment by document id.
         """
-        DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
+        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
 
 

+ 17 - 14
api/core/llm/llm_builder.py

@@ -1,7 +1,6 @@
-from typing import Union, Optional
+from typing import Union, Optional, List
 
-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler
 
 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
     """
 
     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)
 
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
             if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
         else:
             raise ValueError(f"model name {model_name} is not supported.")
 
-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
 
         return llm_cls(
             model_name=model_name,
             temperature=kwargs.get('temperature', 0),
             max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
             streaming=kwargs.get('streaming', False),
             # request_timeout=None
             **model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:
 
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         model_name = model.get("name")
         completion_params = model.get("completion_params", {})
 
@@ -82,7 +85,7 @@ class LLMBuilder:
             frequency_penalty=completion_params.get('frequency_penalty', 0.1),
             presence_penalty=completion_params.get('presence_penalty', 0.1),
             streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
         )
 
     @classmethod

+ 4 - 1
api/core/llm/provider/azure_provider.py

@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config
 
     def get_provider_name(self):

+ 13 - 50
api/core/llm/streamable_azure_chat_open_ai.py

@@ -1,3 +1,4 @@
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
 
         return message_tokens
 
-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)

+ 13 - 6
api/core/llm/streamable_azure_open_ai.py

@@ -1,5 +1,4 @@
-import os
-
+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
 
     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)

+ 14 - 48
api/core/llm/streamable_chat_open_ai.py

@@ -1,6 +1,7 @@
 import os
 
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
 
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
 
         return message_tokens
 
-    def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)

+ 13 - 5
api/core/llm/streamable_open_ai.py

@@ -1,5 +1,6 @@
 import os
 
+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-
     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)
 
     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)

+ 1 - 1
api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py

@@ -1,7 +1,7 @@
 from typing import Any, List, Dict
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string
 
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory

+ 0 - 19
api/core/prompt/prompts.py

@@ -1,5 +1,3 @@
-from llama_index import QueryKeywordExtractPrompt
-
 CONVERSATION_TITLE_PROMPT = (
     "Human:{query}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "[\"question1\",\"question2\",\"question3\"]\n"
 )
 
-QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
-    "A question is provided below. Given the question, extract up to {max_keywords} "
-    "keywords from the text. Focus on extracting the keywords that we can use "
-    "to best lookup answers to the question. Avoid stopwords."
-    "I am not sure which language the following question is in. "
-    "If the user asked the question in Chinese, please return the keywords in Chinese. "
-    "If the user asked the question in English, please return the keywords in English.\n"
-    "---------------------\n"
-    "{question}\n"
-    "---------------------\n"
-    "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
-)
-
-QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
-    QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
-)
-
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
 the model prompt that best suits the input. 
 You will be provided with the prompt, variables, and an opening statement. 

+ 0 - 0
api/core/index/spiltter/fixed_text_splitter.py → api/core/spiltter/fixed_text_splitter.py


+ 87 - 0
api/core/tool/dataset_index_tool.py

@@ -0,0 +1,87 @@
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.tools import BaseTool
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
+from models.dataset import Dataset
+
+
+class DatasetTool(BaseTool):
+    """Tool for querying a Dataset."""
+
+    dataset: Dataset
+    k: int = 2
+
+    def _run(self, tool_input: str) -> str:
+        if self.dataset.indexing_technique == "economy":
+            # use keyword table query
+            kw_table_index = KeywordTableIndex(
+                dataset=self.dataset,
+                config=KeywordTableConfig(
+                    max_keywords_per_chunk=5
+                )
+            )
+
+            documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
+        else:
+            model_credentials = LLMBuilder.get_model_credentials(
+                tenant_id=self.dataset.tenant_id,
+                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                model_name='text-embedding-ada-002'
+            )
+
+            embeddings = CacheEmbedding(OpenAIEmbeddings(
+                **model_credentials
+            ))
+
+            vector_index = VectorIndex(
+                dataset=self.dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+
+            documents = vector_index.search(
+                tool_input,
+                search_type='similarity',
+                search_kwargs={
+                    'k': self.k
+                }
+            )
+
+            hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
+            hit_callback.on_tool_end(documents)
+
+        return str("\n".join([document.page_content for document in documents]))
+
+    async def _arun(self, tool_input: str) -> str:
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=self.dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+            model_name='text-embedding-ada-002'
+        )
+
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=self.dataset,
+            config=current_app.config,
+            embeddings=embeddings
+        )
+
+        documents = await vector_index.asearch(
+            tool_input,
+            search_type='similarity',
+            search_kwargs={
+                'k': 10
+            }
+        )
+
+        hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
+        hit_callback.on_tool_end(documents)
+        return str("\n".join([document.page_content for document in documents]))

+ 0 - 73
api/core/tool/dataset_tool_builder.py

@@ -1,73 +0,0 @@
-from typing import Optional
-
-from langchain.callbacks import CallbackManager
-from llama_index.langchain_helpers.agents import IndexToolConfig
-
-from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
-from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
-
-
-class DatasetToolBuilder:
-    @classmethod
-    def build_dataset_tool(cls, dataset: Dataset,
-                           response_mode: str = "no_synthesizer",
-                           callback_handler: Optional[DatasetToolCallbackHandler] = None):
-        if dataset.indexing_technique == "economy":
-            # use keyword table query
-            index = KeywordTableIndex(dataset=dataset).query_index
-
-            if not index:
-                return None
-
-            query_kwargs = {
-                "mode": "default",
-                "response_mode": response_mode,
-                "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
-                "max_keywords_per_query": 5,
-                # If num_chunks_per_query is too large,
-                # it will slow down the synthesis process due to multiple iterations of refinement.
-                "num_chunks_per_query": 2
-            }
-        else:
-            index = VectorIndex(dataset=dataset).query_index
-
-            if not index:
-                return None
-
-            query_kwargs = {
-                "mode": "default",
-                "response_mode": response_mode,
-                # If top_k is too large,
-                # it will slow down the synthesis process due to multiple iterations of refinement.
-                "similarity_top_k": 2
-            }
-
-        # fulfill description when it is empty
-        description = dataset.description
-        if not description:
-            description = 'useful for when you want to answer queries about the ' + dataset.name
-
-        index_tool_config = IndexToolConfig(
-            index=index,
-            name=f"dataset-{dataset.id}",
-            description=description,
-            index_query_kwargs=query_kwargs,
-            tool_kwargs={
-                "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
-            },
-            # tool_kwargs={"return_direct": True},
-            # return_direct: Whether to return LLM results directly or process the output data with an Output Parser
-        )
-
-        index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
-
-        return EnhanceLlamaIndexTool.from_tool_config(
-            tool_config=index_tool_config,
-            callback_handler=index_callback_handler
-        )

+ 0 - 43
api/core/tool/llama_index_tool.py

@@ -1,43 +0,0 @@
-from typing import Dict
-
-from langchain.tools import BaseTool
-from llama_index.indices.base import BaseGPTIndex
-from llama_index.langchain_helpers.agents import IndexToolConfig
-from pydantic import Field
-
-from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
-
-
-class EnhanceLlamaIndexTool(BaseTool):
-    """Tool for querying a LlamaIndex."""
-
-    # NOTE: name/description still needs to be set
-    index: BaseGPTIndex
-    query_kwargs: Dict = Field(default_factory=dict)
-    return_sources: bool = False
-    callback_handler: IndexToolCallbackHandler
-
-    @classmethod
-    def from_tool_config(cls, tool_config: IndexToolConfig,
-                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
-        """Create a tool from a tool config."""
-        return_sources = tool_config.tool_kwargs.pop("return_sources", False)
-        return cls(
-            index=tool_config.index,
-            callback_handler=callback_handler,
-            name=tool_config.name,
-            description=tool_config.description,
-            return_sources=return_sources,
-            query_kwargs=tool_config.index_query_kwargs,
-            **tool_config.tool_kwargs,
-        )
-
-    def _run(self, tool_input: str) -> str:
-        response = self.index.query(tool_input, **self.query_kwargs)
-        self.callback_handler.on_tool_end(response)
-        return str(response)
-
-    async def _arun(self, tool_input: str) -> str:
-        response = await self.index.aquery(tool_input, **self.query_kwargs)
-        self.callback_handler.on_tool_end(response)
-        return str(response)

+ 0 - 34
api/core/vector_store/base.py

@@ -1,34 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional
-
-from llama_index import ServiceContext, GPTVectorStoreIndex
-from llama_index.data_structs import Node
-from llama_index.vector_stores.types import VectorStore
-
-
-class BaseVectorStoreClient(ABC):
-    @abstractmethod
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        raise NotImplementedError
-
-    @abstractmethod
-    def to_index_config(self, index_id: str) -> dict:
-        raise NotImplementedError
-
-
-class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
-    def delete_node(self, node_id: str):
-        self._vector_store.delete_node(node_id)
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        return self._vector_store.exists_by_node_id(node_id)
-
-
-class EnhanceVectorStore(ABC):
-    @abstractmethod
-    def delete_node(self, node_id: str):
-        pass
-
-    @abstractmethod
-    def exists_by_node_id(self, node_id: str) -> bool:
-        pass

+ 69 - 0
api/core/vector_store/qdrant_vector_store.py

@@ -0,0 +1,69 @@
+from typing import cast, Any
+
+from langchain.schema import Document
+from langchain.vectorstores import Qdrant
+from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
+from qdrant_client.local.qdrant_local import QdrantLocal
+
+
+class QdrantVectorStore(Qdrant):
+    def del_texts(self, filter: Filter):
+        if not filter:
+            raise ValueError('filter must not be empty')
+
+        self._reload_if_needed()
+
+        self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=FilterSelector(
+                filter=filter
+            ),
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._reload_if_needed()
+
+        self.client.delete(
+            collection_name=self.collection_name,
+            points_selector=PointIdsList(
+                points=[uuid],
+            ),
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        self._reload_if_needed()
+
+        response = self.client.retrieve(
+            collection_name=self.collection_name,
+            ids=[uuid]
+        )
+
+        return len(response) > 0
+
+    def delete(self):
+        self._reload_if_needed()
+
+        self.client.delete_collection(collection_name=self.collection_name)
+
+    @classmethod
+    def _document_from_scored_point(
+            cls,
+            scored_point: Any,
+            content_payload_key: str,
+            metadata_payload_key: str,
+    ) -> Document:
+        if scored_point.payload.get('doc_id'):
+            return Document(
+                page_content=scored_point.payload.get(content_payload_key),
+                metadata={'doc_id': scored_point.id}
+            )
+
+        return Document(
+            page_content=scored_point.payload.get(content_payload_key),
+            metadata=scored_point.payload.get(metadata_payload_key) or {},
+        )
+
+    def _reload_if_needed(self):
+        if isinstance(self.client, QdrantLocal):
+            self.client = cast(QdrantLocal, self.client)
+            self.client._load()

+ 0 - 147
api/core/vector_store/qdrant_vector_store_client.py

@@ -1,147 +0,0 @@
-import os
-from typing import cast, List
-
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
-from qdrant_client.http.models import Payload, Filter
-
-import qdrant_client
-from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
-from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
-from llama_index.vector_stores import QdrantVectorStore
-from qdrant_client.local.qdrant_local import QdrantLocal
-
-from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
-
-
-class QdrantVectorStoreClient(BaseVectorStoreClient):
-
-    def __init__(self, url: str, api_key: str, root_path: str):
-        self._client = self.init_from_config(url, api_key, root_path)
-
-    @classmethod
-    def init_from_config(cls, url: str, api_key: str, root_path: str):
-        if url and url.startswith('path:'):
-            path = url.replace('path:', '')
-            if not os.path.isabs(path):
-                path = os.path.join(root_path, path)
-
-            return qdrant_client.QdrantClient(
-                path=path
-            )
-        else:
-            return qdrant_client.QdrantClient(
-                url=url,
-                api_key=api_key,
-            )
-
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        index_struct = QdrantIndexDict()
-
-        if self._client is None:
-            raise Exception("Vector client is not initialized.")
-
-        # {"collection_name": "Gpt_index_xxx"}
-        collection_name = config.get('collection_name')
-        if not collection_name:
-            raise Exception("collection_name cannot be None.")
-
-        return GPTQdrantEnhanceIndex(
-            service_context=service_context,
-            index_struct=index_struct,
-            vector_store=QdrantEnhanceVectorStore(
-                client=self._client,
-                collection_name=collection_name
-            )
-        )
-
-    def to_index_config(self, index_id: str) -> dict:
-        return {"collection_name": index_id}
-
-
-class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
-    pass
-
-
-class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
-    def delete_node(self, node_id: str):
-        """
-        Delete node from the index.
-
-        :param node_id: node id
-        """
-        from qdrant_client.http import models as rest
-
-        self._reload_if_needed()
-
-        self._client.delete(
-            collection_name=self._collection_name,
-            points_selector=rest.Filter(
-                must=[
-                    rest.FieldCondition(
-                        key="id", match=rest.MatchValue(value=node_id)
-                    )
-                ]
-            ),
-        )
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        """
-        Get node from the index by node id.
-
-        :param node_id: node id
-        """
-        self._reload_if_needed()
-
-        response = self._client.retrieve(
-            collection_name=self._collection_name,
-            ids=[node_id]
-        )
-
-        return len(response) > 0
-
-    def query(
-        self,
-        query: VectorStoreQuery,
-    ) -> VectorStoreQueryResult:
-        """Query index for top k most similar nodes.
-
-        Args:
-            query (VectorStoreQuery): query
-        """
-        query_embedding = cast(List[float], query.query_embedding)
-
-        self._reload_if_needed()
-
-        response = self._client.search(
-            collection_name=self._collection_name,
-            query_vector=query_embedding,
-            limit=cast(int, query.similarity_top_k),
-            query_filter=cast(Filter, self._build_query_filter(query)),
-            with_vectors=True
-        )
-
-        nodes = []
-        similarities = []
-        ids = []
-        for point in response:
-            payload = cast(Payload, point.payload)
-            node = Node(
-                doc_id=str(point.id),
-                text=payload.get("text"),
-                embedding=point.vector,
-                extra_info=payload.get("extra_info"),
-                relationships={
-                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
-                },
-            )
-            nodes.append(node)
-            similarities.append(point.score)
-            ids.append(str(point.id))
-
-        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
-
-    def _reload_if_needed(self):
-        if isinstance(self._client._client, QdrantLocal):
-            self._client._client._load()

+ 0 - 62
api/core/vector_store/vector_store.py

@@ -1,62 +0,0 @@
-from flask import Flask
-from llama_index import ServiceContext, GPTVectorStoreIndex
-from requests import ReadTimeout
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
-
-from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
-from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
-
-SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
-
-
-class VectorStore:
-
-    def __init__(self):
-        self._vector_store = None
-        self._client = None
-
-    def init_app(self, app: Flask):
-        if not app.config['VECTOR_STORE']:
-            return
-
-        self._vector_store = app.config['VECTOR_STORE']
-        if self._vector_store not in SUPPORTED_VECTOR_STORES:
-            raise ValueError(f"Vector store {self._vector_store} is not supported.")
-
-        if self._vector_store == 'weaviate':
-            self._client = WeaviateVectorStoreClient(
-                endpoint=app.config['WEAVIATE_ENDPOINT'],
-                api_key=app.config['WEAVIATE_API_KEY'],
-                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
-                batch_size=app.config['WEAVIATE_BATCH_SIZE']
-            )
-        elif self._vector_store == 'qdrant':
-            self._client = QdrantVectorStoreClient(
-                url=app.config['QDRANT_URL'],
-                api_key=app.config['QDRANT_API_KEY'],
-                root_path=app.root_path
-            )
-
-        app.extensions['vector_store'] = self
-
-    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
-        vector_store_config: dict = index_struct.get('vector_store')
-        index = self.get_client().get_index(
-            service_context=service_context,
-            config=vector_store_config
-        )
-
-        return index
-
-    def to_index_struct(self, index_id: str) -> dict:
-        return {
-            "type": self._vector_store,
-            "vector_store": self.get_client().to_index_config(index_id)
-        }
-
-    def get_client(self):
-        if not self._client:
-            raise Exception("Vector store client is not initialized.")
-
-        return self._client

+ 0 - 66
api/core/vector_store/vector_store_index_query.py

@@ -1,66 +0,0 @@
-from llama_index.indices.query.base import IS
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional
-)
-
-from llama_index.docstore import BaseDocumentStore
-from llama_index.indices.postprocessor.node import (
-    BaseNodePostprocessor,
-)
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
-from llama_index.indices.response.response_builder import ResponseMode
-from llama_index.indices.service_context import ServiceContext
-from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
-from llama_index.prompts.prompts import (
-    QuestionAnswerPrompt,
-    RefinePrompt,
-    SimpleInputPrompt,
-)
-
-from core.index.query.synthesizer import EnhanceResponseSynthesizer
-
-
-class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
-    @classmethod
-    def from_args(
-            cls,
-            index_struct: IS,
-            service_context: ServiceContext,
-            docstore: Optional[BaseDocumentStore] = None,
-            node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
-            verbose: bool = False,
-            # response synthesizer args
-            response_mode: ResponseMode = ResponseMode.DEFAULT,
-            text_qa_template: Optional[QuestionAnswerPrompt] = None,
-            refine_template: Optional[RefinePrompt] = None,
-            simple_template: Optional[SimpleInputPrompt] = None,
-            response_kwargs: Optional[Dict] = None,
-            use_async: bool = False,
-            streaming: bool = False,
-            optimizer: Optional[BaseTokenUsageOptimizer] = None,
-            # class-specific args
-            **kwargs: Any,
-    ) -> "BaseGPTIndexQuery":
-        response_synthesizer = EnhanceResponseSynthesizer.from_args(
-            service_context=service_context,
-            text_qa_template=text_qa_template,
-            refine_template=refine_template,
-            simple_template=simple_template,
-            response_mode=response_mode,
-            response_kwargs=response_kwargs,
-            use_async=use_async,
-            streaming=streaming,
-            optimizer=optimizer,
-        )
-        return cls(
-            index_struct=index_struct,
-            service_context=service_context,
-            response_synthesizer=response_synthesizer,
-            docstore=docstore,
-            node_postprocessors=node_postprocessors,
-            verbose=verbose,
-            **kwargs,
-        )

+ 38 - 0
api/core/vector_store/weaviate_vector_store.py

@@ -0,0 +1,38 @@
+from langchain.vectorstores import Weaviate
+
+
+class WeaviateVectorStore(Weaviate):
+    def del_texts(self, where_filter: dict):
+        if not where_filter:
+            raise ValueError('where_filter must not be empty')
+
+        self._client.batch.delete_objects(
+            class_name=self._index_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._client.data_object.delete(
+            uuid,
+            class_name=self._index_name
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": uuid,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][self._index_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete(self):
+        self._client.schema.delete_class(self._index_name)

+ 0 - 270
api/core/vector_store/weaviate_vector_store_client.py

@@ -1,270 +0,0 @@
-import json
-import weaviate
-from dataclasses import field
-from typing import List, Any, Dict, Optional
-
-from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
-from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
-from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
-from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
-from llama_index.vector_stores import WeaviateVectorStore
-from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
-from llama_index.readers.weaviate.utils import (
-    parse_get_response,
-    validate_client,
-)
-
-
-class WeaviateVectorStoreClient(BaseVectorStoreClient):
-
-    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
-        self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
-
-    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
-        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
-
-        weaviate.connect.connection.has_grpc = grpc_enabled
-
-        client = weaviate.Client(
-            url=endpoint,
-            auth_client_secret=auth_config,
-            timeout_config=(5, 60),
-            startup_period=None
-        )
-
-        client.batch.configure(
-            # `batch_size` takes an `int` value to enable auto-batching
-            # (`None` is used for manual batching)
-            batch_size=batch_size,
-            # dynamically update the `batch_size` based on import speed
-            dynamic=True,
-            # `timeout_retries` takes an `int` value to retry on time outs
-            timeout_retries=3,
-        )
-
-        return client
-
-    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
-        index_struct = WeaviateIndexDict()
-
-        if self._client is None:
-            raise Exception("Vector client is not initialized.")
-
-        # {"class_prefix": "Gpt_index_xxx"}
-        class_prefix = config.get('class_prefix')
-        if not class_prefix:
-            raise Exception("class_prefix cannot be None.")
-
-        return GPTWeaviateEnhanceIndex(
-            service_context=service_context,
-            index_struct=index_struct,
-            vector_store=WeaviateWithSimilaritiesVectorStore(
-                weaviate_client=self._client,
-                class_prefix=class_prefix
-            )
-        )
-
-    def to_index_config(self, index_id: str) -> dict:
-        return {"class_prefix": index_id}
-
-
-class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
-    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
-        """Query index for top k most similar nodes."""
-        nodes = self.weaviate_query(
-            self._client,
-            self._class_prefix,
-            query,
-        )
-        nodes = nodes[: query.similarity_top_k]
-        node_idxs = [str(i) for i in range(len(nodes))]
-
-        similarities = []
-        for node in nodes:
-            similarities.append(node.extra_info['similarity'])
-            del node.extra_info['similarity']
-
-        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
-
-    def weaviate_query(
-            self,
-            client: Any,
-            class_prefix: str,
-            query_spec: VectorStoreQuery,
-    ) -> List[Node]:
-        """Convert to LlamaIndex list."""
-        validate_client(client)
-
-        class_name = _class_name(class_prefix)
-        prop_names = [p["name"] for p in NODE_SCHEMA]
-        vector = query_spec.query_embedding
-
-        # build query
-        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
-        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
-            _logger.debug("Using vector search")
-            if vector is not None:
-                query = query.with_near_vector(
-                    {
-                        "vector": vector,
-                    }
-                )
-        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
-            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
-            query = query.with_hybrid(
-                query=query_spec.query_str,
-                alpha=query_spec.alpha,
-                vector=vector,
-            )
-        query = query.with_limit(query_spec.similarity_top_k)
-        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")
-
-        # execute query
-        query_result = query.do()
-
-        # parse results
-        parsed_result = parse_get_response(query_result)
-        entries = parsed_result[class_name]
-        results = [self._to_node(entry) for entry in entries]
-        return results
-
-    def _to_node(self, entry: Dict) -> Node:
-        """Convert to Node."""
-        extra_info_str = entry["extra_info"]
-        if extra_info_str == "":
-            extra_info = None
-        else:
-            extra_info = json.loads(extra_info_str)
-
-        if 'certainty' in entry['_additional']:
-            if extra_info:
-                extra_info['similarity'] = entry['_additional']['certainty']
-            else:
-                extra_info = {'similarity': entry['_additional']['certainty']}
-
-        node_info_str = entry["node_info"]
-        if node_info_str == "":
-            node_info = None
-        else:
-            node_info = json.loads(node_info_str)
-
-        relationships_str = entry["relationships"]
-        relationships: Dict[DocumentRelationship, str]
-        if relationships_str == "":
-            relationships = field(default_factory=dict)
-        else:
-            relationships = {
-                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
-            }
-
-        return Node(
-            text=entry["text"],
-            doc_id=entry["doc_id"],
-            embedding=entry["_additional"]["vector"],
-            extra_info=extra_info,
-            node_info=node_info,
-            relationships=relationships,
-        )
-
-    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
-        """Delete a document.
-
-        Args:
-            doc_id (str): document id
-
-        """
-        delete_document(self._client, doc_id, self._class_prefix)
-
-    def delete_node(self, node_id: str):
-        """
-        Delete node from the index.
-
-        :param node_id: node id
-        """
-        delete_node(self._client, node_id, self._class_prefix)
-
-    def exists_by_node_id(self, node_id: str) -> bool:
-        """
-        Get node from the index by node id.
-
-        :param node_id: node id
-        """
-        entry = get_by_node_id(self._client, node_id, self._class_prefix)
-        return True if entry else False
-
-
-class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
-    pass
-
-
-def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["ref_doc_id"],
-        "operator": "Equal",
-        "valueString": ref_doc_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    for entry in entries:
-        client.data_object.delete(entry["_additional"]["id"], class_name)
-
-    while len(entries) > 0:
-        query_result = query.do()
-        parsed_result = parse_get_response(query_result)
-        entries = parsed_result[class_name]
-        for entry in entries:
-            client.data_object.delete(entry["_additional"]["id"], class_name)
-
-
-def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["doc_id"],
-        "operator": "Equal",
-        "valueString": node_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    for entry in entries:
-        client.data_object.delete(entry["_additional"]["id"], class_name)
-
-
-def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
-    """Delete entry."""
-    validate_client(client)
-    # make sure that each entry
-    class_name = _class_name(class_prefix)
-    where_filter = {
-        "path": ["doc_id"],
-        "operator": "Equal",
-        "valueString": node_id,
-    }
-    query = (
-        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
-    )
-
-    query_result = query.do()
-    parsed_result = parse_get_response(query_result)
-    entries = parsed_result[class_name]
-    if len(entries) == 0:
-        return None
-
-    return entries[0]

+ 0 - 7
api/extensions/ext_vector_store.py

@@ -1,7 +0,0 @@
-from core.vector_store.vector_store import VectorStore
-
-vector_store = VectorStore()
-
-
-def init_app(app):
-    vector_store.init_app(app)

+ 6 - 0
api/libs/helper.py

@@ -3,6 +3,7 @@ import re
 import subprocess
 import uuid
 from datetime import datetime
+from hashlib import sha256
 from zoneinfo import available_timezones
 import random
 import string
@@ -147,3 +148,8 @@ def get_remote_ip(request):
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
         return request.remote_addr
+
+
+def generate_text_hash(text: str) -> str:
+    hash_text = str(text) + 'None'
+    return sha256(hash_text.encode()).hexdigest()

+ 0 - 2
api/models/account.py

@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
-    _current_tenant: db.Model = None
-
     @property
     def current_tenant(self):
         return self._current_tenant

+ 30 - 2
api/models/dataset.py

@@ -66,6 +66,23 @@ class Dataset(db.Model):
     def document_count(self):
         return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
 
+    @property
+    def available_document_count(self):
+        return db.session.query(func.count(Document.id)).filter(
+            Document.dataset_id == self.id,
+            Document.indexing_status == 'completed',
+            Document.enabled == True,
+            Document.archived == False
+        ).scalar()
+
+    @property
+    def available_segment_count(self):
+        return db.session.query(func.count(DocumentSegment.id)).filter(
+            DocumentSegment.dataset_id == self.id,
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True
+        ).scalar()
+
     @property
     def word_count(self):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
@@ -260,7 +277,7 @@ class Document(db.Model):
 
     @property
     def dataset(self):
-        return Dataset.query.get(self.dataset_id)
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
 
     @property
     def segment_count(self):
@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):
 
     @property
     def keyword_table_dict(self):
-        return json.loads(self.keyword_table) if self.keyword_table else None
+        class SetDecoder(json.JSONDecoder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+            def object_hook(self, dct):
+                if isinstance(dct, dict):
+                    for keyword, node_idxs in dct.items():
+                        if isinstance(node_idxs, list):
+                            dct[keyword] = set(node_idxs)
+                return dct
+
+        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
 
 
 class Embedding(db.Model):

+ 5 - 4
api/requirements.txt

@@ -2,6 +2,7 @@ coverage~=7.2.4
 beautifulsoup4==4.12.2
 flask~=2.3.2
 Flask-SQLAlchemy~=3.0.3
+SQLAlchemy~=1.4.28
 flask-login==0.6.2
 flask-migrate~=4.0.4
 flask-restful==0.3.9
@@ -9,8 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.142
-llama-index==0.5.27
+langchain==0.0.209
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17
@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
 jieba==0.42.1
 celery==5.2.7
 redis~=4.5.4
-pypdf==3.8.1
 openpyxl==3.1.2
-chardet~=5.1.0
+chardet~=5.1.0
+docx2txt==0.8
+pypdfium2==4.16.0

+ 0 - 1
api/services/app_model_config_service.py

@@ -4,7 +4,6 @@ import uuid
 from core.constant import llm_constant
 from models.account import Account
 from services.dataset_service import DatasetService
-from services.errors.account import NoPermissionError
 
 
 class AppModelConfigService:

+ 0 - 3
api/services/dataset_service.py

@@ -7,7 +7,6 @@ from typing import Optional, List
 from extensions.ext_redis import redis_client
 from flask_login import current_user
 
-from core.index.index_builder import IndexBuilder
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -386,8 +385,6 @@ class DocumentService:
 
             dataset.indexing_technique = document_data["indexing_technique"]
 
-        if dataset.indexing_technique == 'high_quality':
-            IndexBuilder.get_default_service_context(dataset.tenant_id)
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:

+ 44 - 36
api/services/hit_testing_service.py

@@ -3,47 +3,56 @@ import time
 from typing import List
 
 import numpy as np
-from llama_index.data_structs.node_v2 import NodeWithScore
-from llama_index.indices.query.schema import QueryBundle
-from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
+from flask import current_app
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from sklearn.manifold import TSNE
 
-from core.docstore.empty_docstore import EmptyDocumentStore
-from core.index.vector_index import VectorIndex
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.vector_index.vector_index import VectorIndex
+from core.llm.llm_builder import LLMBuilder
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-from services.errors.index import IndexNotInitializedError
 
 
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        index = VectorIndex(dataset=dataset).query_index
-
-        if not index:
-            raise IndexNotInitializedError()
-
-        index_query = GPTVectorStoreIndexQuery(
-            index_struct=index.index_struct,
-            service_context=index.service_context,
-            vector_store=index.query_context.get('vector_store'),
-            docstore=EmptyDocumentStore(),
-            response_synthesizer=None,
-            similarity_top_k=limit
-        )
+        if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+            return {
+                "query": {
+                    "content": query,
+                    "tsne_position": {'x': 0, 'y': 0},
+                },
+                "records": []
+            }
 
-        query_bundle = QueryBundle(
-            query_str=query,
-            custom_embedding_strs=[query],
+        model_credentials = LLMBuilder.get_model_credentials(
+            tenant_id=dataset.tenant_id,
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_name='text-embedding-ada-002'
         )
 
-        query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries(
-            query_bundle.embedding_strs
+        embeddings = CacheEmbedding(OpenAIEmbeddings(
+            **model_credentials
+        ))
+
+        vector_index = VectorIndex(
+            dataset=dataset,
+            config=current_app.config,
+            embeddings=embeddings
         )
 
         start = time.perf_counter()
-        nodes = index_query.retrieve(query_bundle=query_bundle)
+        documents = vector_index.search(
+            query,
+            search_type='similarity_score_threshold',
+            search_kwargs={
+                'k': 10
+            }
+        )
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
 
@@ -58,25 +67,24 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return cls.compact_retrieve_response(dataset, query_bundle, nodes)
+        return cls.compact_retrieve_response(dataset, embeddings, query, documents)
 
     @classmethod
-    def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]):
-        embeddings = [
-            query_bundle.embedding
+    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
+        text_embeddings = [
+            embeddings.embed_query(query)
         ]
 
-        for node in nodes:
-            embeddings.append(node.node.embedding)
+        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
 
-        tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings)
+        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
 
         query_position = tsne_position_data.pop(0)
 
         i = 0
         records = []
-        for node in nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            index_node_id = document.metadata['doc_id']
 
             segment = db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == dataset.id,
@@ -91,7 +99,7 @@ class HitTestingService:
 
             record = {
                 "segment": segment,
-                "score": node.score,
+                "score": document.metadata['score'],
                 "tsne_position": tsne_position_data[i]
             }
 
@@ -101,7 +109,7 @@ class HitTestingService:
 
         return {
             "query": {
-                "content": query_bundle.query_str,
+                "content": query,
                 "tsne_position": query_position,
             },
             "records": records

+ 33 - 48
api/tasks/add_document_to_index_task.py

@@ -4,96 +4,81 @@ import time
 
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from models.dataset import DocumentSegment, Document
+from models.dataset import DocumentSegment
+from models.dataset import Document as DatasetDocument
 
 
 @shared_task
-def add_document_to_index_task(document_id: str):
+def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
     :param document_id:
 
     Usage: add_document_to_index.delay(document_id)
     """
-    logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green'))
+    logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
     start_at = time.perf_counter()
 
-    document = db.session.query(Document).filter(Document.id == document_id).first()
-    if not document:
+    dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+    if not dataset_document:
         raise NotFound('Document not found')
 
-    if document.indexing_status != 'completed':
+    if dataset_document.indexing_status != 'completed':
         return
 
-    indexing_cache_key = 'document_{}_indexing'.format(document.id)
+    indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
 
     try:
         segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.document_id == document.id,
+            DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.enabled == True
         ) \
             .order_by(DocumentSegment.position.asc()).all()
 
-        nodes = []
-        previous_node = None
+        documents = []
         for segment in segments:
-            relationships = {
-                DocumentRelationship.SOURCE: document.id
-            }
-
-            if previous_node:
-                relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-            node = Node(
-                doc_id=segment.index_node_id,
-                doc_hash=segment.index_node_hash,
-                text=segment.content,
-                extra_info=None,
-                node_info=None,
-                relationships=relationships
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
             )
 
-            previous_node = node
+            documents.append(document)
 
-            nodes.append(node)
-
-        dataset = document.dataset
+        dataset = dataset_document.dataset
 
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=nodes,
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents)
 
         # save keyword index
-        keyword_table_index.add_nodes(nodes)
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts(documents)
 
         end_at = time.perf_counter()
         logging.info(
-            click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
+            click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
     except Exception as e:
         logging.exception("add document to index failed")
-        document.enabled = False
-        document.disabled_at = datetime.datetime.utcnow()
-        document.status = 'error'
-        document.error = str(e)
+        dataset_document.enabled = False
+        dataset_document.disabled_at = datetime.datetime.utcnow()
+        dataset_document.status = 'error'
+        dataset_document.error = str(e)
         db.session.commit()
     finally:
         redis_client.delete(indexing_cache_key)

+ 27 - 32
api/tasks/add_segment_to_index_task.py

@@ -4,12 +4,10 @@ import time
 
 import click
 from celery import shared_task
-from llama_index.data_structs import Node
-from llama_index.data_structs.node_v2 import DocumentRelationship
+from langchain.schema import Document
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
     indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
 
     try:
-        relationships = {
-            DocumentRelationship.SOURCE: segment.document_id,
-        }
-
-        previous_segment = segment.previous_segment
-        if previous_segment:
-            relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
-
-        next_segment = segment.next_segment
-        if next_segment:
-            relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
-
-        node = Node(
-            doc_id=segment.index_node_id,
-            doc_hash=segment.index_node_hash,
-            text=segment.content,
-            extra_info=None,
-            node_info=None,
-            relationships=relationships
+        document = Document(
+            page_content=segment.content,
+            metadata={
+                "doc_id": segment.index_node_id,
+                "doc_hash": segment.index_node_hash,
+                "document_id": segment.document_id,
+                "dataset_id": segment.dataset_id,
+            }
         )
 
         dataset = segment.dataset
 
         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
 
         # save vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.add_nodes(
-                nodes=[node],
-                duplicate_check=True
-            )
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts([document], duplicate_check=True)
 
         # save keyword index
-        keyword_table_index.add_nodes([node])
+        index = IndexBuilder.get_index(dataset, 'economy')
+        if index:
+            index.add_texts([document])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))

+ 13 - 20
api/tasks/clean_dataset_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
     AppDatasetJoin
@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             index_struct=index_struct
         )
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
-
         documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-        index_doc_ids = [document.id for document in documents]
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
-        index_node_ids = [segment.index_node_id for segment in segments]
 
-        # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            for index_doc_id in index_doc_ids:
-                try:
-                    vector_index.del_doc(index_doc_id)
-                except Exception:
-                    logging.exception("Delete doc index failed when dataset deleted.")
-                    continue
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
-        # delete from keyword index
-        if index_node_ids:
+        # delete from vector index
+        if vector_index:
             try:
-                keyword_table_index.del_nodes(index_node_ids)
+                vector_index.delete()
             except Exception:
-                logging.exception("Delete nodes index failed when dataset deleted.")
+                logging.exception("Delete doc index failed when dataset deleted.")
+
+        # delete from keyword index
+        try:
+            kw_index.delete()
+        except Exception:
+            logging.exception("Delete nodes index failed when dataset deleted.")
 
         for document in documents:
             db.session.delete(document)
@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         for segment in segments:
             db.session.delete(segment)
 
-        db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
         db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
         db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
         db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()

+ 7 - 6
api/tasks/clean_document_task.py

@@ -4,8 +4,7 @@ import time
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset
 
@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_document_id(document_id)
 
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
         for segment in segments:
             db.session.delete(segment)
+
         db.session.commit()
         end_at = time.perf_counter()
         logging.info(

+ 7 - 6
api/tasks/clean_notion_document_task.py

@@ -5,8 +5,7 @@ from typing import List
 import click
 from celery import shared_task
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Dataset, Document
 
@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
         for document_id in document_ids:
             document = db.session.query(Document).filter(
                 Document.id == document_id
             ).first()
             db.session.delete(document)
+
             segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
             # delete from vector index
-            vector_index.del_nodes(index_node_ids)
+            if vector_index:
+                vector_index.delete_by_document_id(document_id)
 
             # delete from keyword index
             if index_node_ids:
-                keyword_table_index.del_nodes(index_node_ids)
+                kw_index.delete_by_ids(index_node_ids)
 
             for segment in segments:
                 db.session.delete(segment)

+ 36 - 36
api/tasks/deal_dataset_vector_index_task.py

@@ -3,10 +3,12 @@ import time
 
 import click
 from celery import shared_task
-from llama_index.data_structs.node_v2 import DocumentRelationship, Node
-from core.index.vector_index import VectorIndex
+from langchain.schema import Document
+
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
-from models.dataset import DocumentSegment, Document, Dataset
+from models.dataset import DocumentSegment, Dataset
+from models.dataset import Document as DatasetDocument
 
 
 @shared_task
@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
         dataset = Dataset.query.filter_by(
             id=dataset_id
         ).first()
+
         if not dataset:
             raise Exception('Dataset not found')
-        documents = Document.query.filter_by(dataset_id=dataset_id).all()
-        if documents:
-            vector_index = VectorIndex(dataset=dataset)
-            for document in documents:
-                # delete from vector index
-                if action == "remove":
-                    vector_index.del_doc(document.id)
-                elif action == "add":
+
+        if action == "remove":
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index.delete()
+        elif action == "add":
+            dataset_documents = db.session.query(DatasetDocument).filter(
+                DatasetDocument.dataset_id == dataset_id,
+                DatasetDocument.indexing_status == 'completed',
+                DatasetDocument.enabled == True,
+                DatasetDocument.archived == False,
+            ).all()
+
+            if dataset_documents:
+                # save vector index
+                index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+                for dataset_document in dataset_documents:
+                    # delete from vector index
                     segments = db.session.query(DocumentSegment).filter(
-                        DocumentSegment.document_id == document.id,
+                        DocumentSegment.document_id == dataset_document.id,
                         DocumentSegment.enabled == True
                     ) .order_by(DocumentSegment.position.asc()).all()
 
-                    nodes = []
-                    previous_node = None
+                    documents = []
                     for segment in segments:
-                        relationships = {
-                            DocumentRelationship.SOURCE: document.id
-                        }
-
-                        if previous_node:
-                            relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
-
-                            previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id
-
-                        node = Node(
-                            doc_id=segment.index_node_id,
-                            doc_hash=segment.index_node_hash,
-                            text=segment.content,
-                            extra_info=None,
-                            node_info=None,
-                            relationships=relationships
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
                         )
 
-                        previous_node = node
-                        nodes.append(node)
+                        documents.append(document)
+
                     # save vector index
-                    vector_index.add_nodes(
-                        nodes=nodes,
-                        duplicate_check=True
-                    )
+                    index.add_texts(documents)
 
         end_at = time.perf_counter()
         logging.info(

+ 23 - 23
api/tasks/document_indexing_sync_task.py

@@ -6,11 +6,9 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.data_source.notion import NotionPageReader
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.data_loader.loader.notion import NotionLoader
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 from models.source import DataSourceBinding
@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             raise ValueError("no notion page found")
         workspace_id = data_source_info['notion_workspace_id']
         page_id = data_source_info['notion_page_id']
+        page_type = data_source_info['type']
         page_edited_time = data_source_info['last_edited_time']
         data_source_binding = DataSourceBinding.query.filter(
             db.and_(
@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         ).first()
         if not data_source_binding:
             raise ValueError('Data source binding not found.')
-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        last_edited_time = reader.get_page_last_edited_time(page_id)
+
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        last_edited_time = loader.get_notion_last_edited_time()
+
         # check the page is updated
         if last_edited_time != page_edited_time:
             document.indexing_status = 'parsing'
@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                 if not dataset:
                     raise Exception('Dataset not found')
 
-                vector_index = VectorIndex(dataset=dataset)
-                keyword_table_index = KeywordTableIndex(dataset=dataset)
+                vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+                kw_index = IndexBuilder.get_index(dataset, 'economy')
 
                 segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                 index_node_ids = [segment.index_node_id for segment in segments]
 
                 # delete from vector index
-                vector_index.del_nodes(index_node_ids)
+                if vector_index:
+                    vector_index.delete_by_document_id(document_id)
 
                 # delete from keyword index
                 if index_node_ids:
-                    keyword_table_index.del_nodes(index_node_ids)
+                    kw_index.delete_by_ids(index_node_ids)
 
                 for segment in segments:
                     db.session.delete(segment)
@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
                     click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
             except Exception:
                 logging.exception("Cleaned document when document update data source or process rule failed")
+
             try:
                 indexing_runner = IndexingRunner()
                 indexing_runner.run([document])
                 end_at = time.perf_counter()
                 logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-            except DocumentIsPausedException:
-                logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-            except ProviderTokenNotInitError as e:
-                document.indexing_status = 'error'
-                document.error = str(e.description)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
-            except Exception as e:
-                logging.exception("consume update document failed")
-                document.indexing_status = 'error'
-                document.error = str(e)
-                document.stopped_at = datetime.datetime.utcnow()
-                db.session.commit()
+            except DocumentIsPausedException as ex:
+                logging.info(click.style(str(ex), fg='yellow'))
+            except Exception:
+                pass

+ 6 - 16
api/tasks/document_indexing_task.py

@@ -7,7 +7,6 @@ from celery import shared_task
 from werkzeug.exceptions import NotFound
 
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document
 
@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
     Usage: document_indexing_task.delay(dataset_id, document_id)
     """
     documents = []
+    start_at = time.perf_counter()
     for document_id in document_ids:
         logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
-        start_at = time.perf_counter()
 
         document = db.session.query(Document).filter(
             Document.id == document_id,
@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 11 - 20
api/tasks/document_indexing_update_task.py

@@ -6,10 +6,8 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from core.indexing_runner import IndexingRunner, DocumentIsPausedException
-from core.llm.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from models.dataset import Document, Dataset, DocumentSegment
 
@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         if not dataset:
             raise Exception('Dataset not found')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
 
         # delete from vector index
-        vector_index.del_nodes(index_node_ids)
+        if vector_index:
+            vector_index.delete_by_ids(index_node_ids)
 
         # delete from keyword index
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
         for segment in segments:
             db.session.delete(segment)
@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
             click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
     except Exception:
         logging.exception("Cleaned document when document update data source or process rule failed")
+
     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run([document])
         end_at = time.perf_counter()
         logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow'))
-    except ProviderTokenNotInitError as e:
-        document.indexing_status = 'error'
-        document.error = str(e.description)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
-    except Exception as e:
-        logging.exception("consume update document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 4 - 9
api/tasks/recover_document_indexing_task.py

@@ -1,4 +1,3 @@
-import datetime
 import logging
 import time
 
@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
             indexing_runner.run_in_indexing_status(document)
         end_at = time.perf_counter()
         logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
-    except DocumentIsPausedException:
-        logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow'))
-    except Exception as e:
-        logging.exception("consume document failed")
-        document.indexing_status = 'error'
-        document.error = str(e)
-        document.stopped_at = datetime.datetime.utcnow()
-        db.session.commit()
+    except DocumentIsPausedException as ex:
+        logging.info(click.style(str(ex), fg='yellow'))
+    except Exception:
+        pass

+ 5 - 6
api/tasks/remove_document_from_index_task.py

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment, Document
@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
         if not dataset:
             raise Exception('Document has no dataset')
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         # delete from vector index
-        vector_index.del_doc(document.id)
+        vector_index.delete_by_document_id(document.id)
 
         # delete from keyword index
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
         index_node_ids = [segment.index_node_id for segment in segments]
         if index_node_ids:
-            keyword_table_index.del_nodes(index_node_ids)
+            kw_index.delete_by_ids(index_node_ids)
 
         end_at = time.perf_counter()
         logging.info(

+ 18 - 8
api/tasks/remove_segment_from_index_task.py

@@ -5,8 +5,7 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
-from core.index.keyword_table_index import KeywordTableIndex
-from core.index.vector_index import VectorIndex
+from core.index.index import IndexBuilder
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DocumentSegment
@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
         dataset = segment.dataset
 
         if not dataset:
-            raise Exception('Segment has no dataset')
+            logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
+            return
 
-        vector_index = VectorIndex(dataset=dataset)
-        keyword_table_index = KeywordTableIndex(dataset=dataset)
+        dataset_document = segment.document
+
+        if not dataset_document:
+            logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
+            logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
+            return
+
+        vector_index = IndexBuilder.get_index(dataset, 'high_quality')
+        kw_index = IndexBuilder.get_index(dataset, 'economy')
 
         # delete from vector index
-        if dataset.indexing_technique == "high_quality":
-            vector_index.del_nodes([segment.index_node_id])
+        if vector_index:
+            vector_index.delete_by_ids([segment.index_node_id])
 
         # delete from keyword index
-        keyword_table_index.del_nodes([segment.index_node_id])
+        kw_index.delete_by_ids([segment.index_node_id])
 
         end_at = time.perf_counter()
         logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))