Explorar el Código

add mutil-thread document embedding (#3016)

Co-authored-by: jyong <jyong@dify.ai>
Jyong hace 1 año
padre
commit
b0b0cc045f

+ 33 - 16
api/core/indexing_runner.py

@@ -1,3 +1,4 @@
+import concurrent.futures
 import datetime
 import json
 import logging
@@ -650,17 +651,44 @@ class IndexingRunner:
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
-        chunk_size = 100
+        chunk_size = 10
 
         embedding_model_type_instance = None
         if embedding_model_instance:
             embedding_model_type_instance = embedding_model_instance.model_type_instance
             embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
+        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+            futures = []
+            for i in range(0, len(documents), chunk_size):
+                chunk_documents = documents[i:i + chunk_size]
+                futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
+                                               chunk_documents, dataset,
+                                               dataset_document, embedding_model_instance,
+                                               embedding_model_type_instance))
+
+            for future in futures:
+                tokens += future.result()
 
-        for i in range(0, len(documents), chunk_size):
+        indexing_end_at = time.perf_counter()
+
+        # update document status to completed
+        self._update_document_index_status(
+            document_id=dataset_document.id,
+            after_indexing_status="completed",
+            extra_update_params={
+                DatasetDocument.tokens: tokens,
+                DatasetDocument.completed_at: datetime.datetime.utcnow(),
+                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
+            }
+        )
+
+    def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
+                       embedding_model_instance, embedding_model_type_instance):
+        with flask_app.app_context():
             # check document is paused
             self._check_document_paused_status(dataset_document.id)
-            chunk_documents = documents[i:i + chunk_size]
+
+            tokens = 0
             if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                 tokens += sum(
                     embedding_model_type_instance.get_num_tokens(
@@ -670,9 +698,9 @@ class IndexingRunner:
                     )
                     for document in chunk_documents
                 )
+
             # load index
             index_processor.load(dataset, chunk_documents)
-            db.session.add(dataset)
 
             document_ids = [document.metadata['doc_id'] for document in chunk_documents]
             db.session.query(DocumentSegment).filter(
@@ -687,18 +715,7 @@ class IndexingRunner:
 
             db.session.commit()
 
-        indexing_end_at = time.perf_counter()
-
-        # update document status to completed
-        self._update_document_index_status(
-            document_id=dataset_document.id,
-            after_indexing_status="completed",
-            extra_update_params={
-                DatasetDocument.tokens: tokens,
-                DatasetDocument.completed_at: datetime.datetime.utcnow(),
-                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
-            }
-        )
+            return tokens
 
     def _check_document_paused_status(self, document_id: str):
         indexing_cache_key = 'document_{}_is_paused'.format(document_id)

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_doc_extractor.py

@@ -53,7 +53,7 @@ class UnstructuredWordExtractor(BaseExtractor):
             elements = partition_docx(filename=self._file_path)
 
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_eml_extractor.py

@@ -43,7 +43,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
             pass
 
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py

@@ -38,7 +38,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
 
         elements = partition_md(filename=self._file_path, api_url=self._api_url)
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_msg_extractor.py

@@ -28,7 +28,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
 
         elements = partition_msg(filename=self._file_path, api_url=self._api_url)
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_text_extractor.py

@@ -28,7 +28,7 @@ class UnstructuredTextExtractor(BaseExtractor):
 
         elements = partition_text(filename=self._file_path, api_url=self._api_url)
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()

+ 1 - 1
api/core/rag/extractor/unstructured/unstructured_xml_extractor.py

@@ -28,7 +28,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
 
         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
         from unstructured.chunking.title import chunk_by_title
-        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0)
+        chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
         documents = []
         for chunk in chunks:
             text = chunk.text.strip()