浏览代码

fix keyword index error when storage source is S3 (#3182)

Jyong 1 年之前
父节点
当前提交
283979fc46
共有 2 个文件被更改,包括 114 次插入80 次删除
  1. 39 13
      api/core/indexing_runner.py
  2. 75 67
      api/core/rag/datasource/keyword/jieba/jieba.py

+ 39 - 13
api/core/indexing_runner.py

@@ -19,6 +19,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -657,18 +658,25 @@ class IndexingRunner:
         if embedding_model_instance:
             embedding_model_type_instance = embedding_model_instance.model_type_instance
             embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
-        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-            futures = []
-            for i in range(0, len(documents), chunk_size):
-                chunk_documents = documents[i:i + chunk_size]
-                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
-                                               chunk_documents, dataset,
-                                               dataset_document, embedding_model_instance,
-                                               embedding_model_type_instance))
-
-            for future in futures:
-                tokens += future.result()
-
+        # create keyword index
+        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
+                                                 args=(current_app._get_current_object(),
+                                                       dataset, dataset_document, documents))
+        create_keyword_thread.start()
+        if dataset.indexing_technique == 'high_quality':
+            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+                futures = []
+                for i in range(0, len(documents), chunk_size):
+                    chunk_documents = documents[i:i + chunk_size]
+                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
+                                                   chunk_documents, dataset,
+                                                   dataset_document, embedding_model_instance,
+                                                   embedding_model_type_instance))
+
+                for future in futures:
+                    tokens += future.result()
+
+        create_keyword_thread.join()
         indexing_end_at = time.perf_counter()
 
         # update document status to completed
@@ -682,6 +690,24 @@ class IndexingRunner:
             }
         )
 
+    def _process_keyword_index(self, flask_app, dataset, dataset_document, documents):
+        with flask_app.app_context():
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+            if dataset.indexing_technique != 'high_quality':
+                document_ids = [document.metadata['doc_id'] for document in documents]
+                db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == dataset_document.id,
+                    DocumentSegment.index_node_id.in_(document_ids),
+                    DocumentSegment.status == "indexing"
+                ).update({
+                    DocumentSegment.status: "completed",
+                    DocumentSegment.enabled: True,
+                    DocumentSegment.completed_at: datetime.datetime.utcnow()
+                })
+
+                db.session.commit()
+
     def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                        embedding_model_instance, embedding_model_type_instance):
         with flask_app.app_context():
@@ -700,7 +726,7 @@ class IndexingRunner:
                 )
 
             # load index
-            index_processor.load(dataset, chunk_documents)
+            index_processor.load(dataset, chunk_documents, with_keywords=False)
 
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(

+ 75 - 67
api/core/rag/datasource/keyword/jieba/jieba.py

@@ -24,56 +24,64 @@ class Jieba(BaseKeyword):
         self._config = KeywordTableConfig()
 
     def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+            keyword_table = self._get_dataset_keyword_table()
+            for text in texts:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
 
-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)
 
-        return self
+            return self
 
     def add_texts(self, texts: list[Document], **kwargs):
-        keyword_table_handler = JiebaKeywordTableHandler()
-
-        keyword_table = self._get_dataset_keyword_table()
-        keywords_list = kwargs.get('keywords_list', None)
-        for i in range(len(texts)):
-            text = texts[i]
-            if keywords_list:
-                keywords = keywords_list[i]
-            else:
-                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        self._save_dataset_keyword_table(keyword_table)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table_handler = JiebaKeywordTableHandler()
+
+            keyword_table = self._get_dataset_keyword_table()
+            keywords_list = kwargs.get('keywords_list', None)
+            for i in range(len(texts)):
+                text = texts[i]
+                if keywords_list:
+                    keywords = keywords_list[i]
+                else:
+                    keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+            self._save_dataset_keyword_table(keyword_table)
 
     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
         return id in set.union(*keyword_table.values())
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
 
-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)
 
     def delete_by_document_id(self, document_id: str):
-        # get segment ids by document_id
-        segments = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == self.dataset.id,
-            DocumentSegment.document_id == document_id
-        ).all()
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            # get segment ids by document_id
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.dataset_id == self.dataset.id,
+                DocumentSegment.document_id == document_id
+            ).all()
 
-        ids = [segment.index_node_id for segment in segments]
+            ids = [segment.index_node_id for segment in segments]
 
-        keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+            keyword_table = self._get_dataset_keyword_table()
+            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
 
-        self._save_dataset_keyword_table(keyword_table)
+            self._save_dataset_keyword_table(keyword_table)
 
     def search(
             self, query: str,
@@ -106,13 +114,15 @@ class Jieba(BaseKeyword):
         return documents
 
     def delete(self) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-            if dataset_keyword_table.data_source_type != 'database':
-                file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
-                storage.delete(file_key)
+        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            dataset_keyword_table = self.dataset.dataset_keyword_table
+            if dataset_keyword_table:
+                db.session.delete(dataset_keyword_table)
+                db.session.commit()
+                if dataset_keyword_table.data_source_type != 'database':
+                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                    storage.delete(file_key)
 
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
@@ -135,33 +145,31 @@ class Jieba(BaseKeyword):
             storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
 
     def _get_dataset_keyword_table(self) -> Optional[dict]:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
-        with redis_client.lock(lock_name, timeout=20):
-            dataset_keyword_table = self.dataset.dataset_keyword_table
-            if dataset_keyword_table:
-                keyword_table_dict = dataset_keyword_table.keyword_table_dict
-                if keyword_table_dict:
-                    return keyword_table_dict['__data__']['table']
-            else:
-                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
-                dataset_keyword_table = DatasetKeywordTable(
-                    dataset_id=self.dataset.id,
-                    keyword_table='',
-                    data_source_type=keyword_data_source_type,
-                )
-                if keyword_data_source_type == 'database':
-                    dataset_keyword_table.keyword_table = json.dumps({
-                        '__type__': 'keyword_table',
-                        '__data__': {
-                            "index_id": self.dataset.id,
-                            "summary": None,
-                            "table": {}
-                        }
-                    }, cls=SetEncoder)
-                db.session.add(dataset_keyword_table)
-                db.session.commit()
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            keyword_table_dict = dataset_keyword_table.keyword_table_dict
+            if keyword_table_dict:
+                return keyword_table_dict['__data__']['table']
+        else:
+            keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
+            dataset_keyword_table = DatasetKeywordTable(
+                dataset_id=self.dataset.id,
+                keyword_table='',
+                data_source_type=keyword_data_source_type,
+            )
+            if keyword_data_source_type == 'database':
+                dataset_keyword_table.keyword_table = json.dumps({
+                    '__type__': 'keyword_table',
+                    '__data__': {
+                        "index_id": self.dataset.id,
+                        "summary": None,
+                        "table": {}
+                    }
+                }, cls=SetEncoder)
+            db.session.add(dataset_keyword_table)
+            db.session.commit()
 
-            return {}
+        return {}
 
     def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
         for keyword in keywords: