Jyong преди 1 година
родител
ревизия
174ebb51db
променени са 1 файла, в които са добавени 35 реда и са изтрити 36 реда
  1. 35 36
      api/core/indexing_runner.py

+ 35 - 36
api/core/indexing_runner.py

@@ -494,6 +494,7 @@ class IndexingRunner:
         Split the text documents into nodes.
         """
         all_documents = []
+        all_qa_documents = []
         for text_doc in text_docs:
             # document clean
             document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -502,58 +503,56 @@ class IndexingRunner:
             # parse document to nodes
             documents = splitter.split_documents([text_doc])
             split_documents = []
+            for document_node in documents:
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(document_node.page_content)
+                document_node.metadata['doc_id'] = doc_id
+                document_node.metadata['doc_hash'] = hash
+
+                split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        # processing qa document
+        if document_form == 'qa_model':
             llm: StreamableOpenAI = LLMBuilder.to_llm(
                 tenant_id=tenant_id,
                 model_name='gpt-3.5-turbo',
                 max_tokens=2000
             )
-            for i in range(0, len(documents), 10):
+            for i in range(0, len(all_documents), 10):
                 threads = []
-                sub_documents = documents[i:i + 10]
+                sub_documents = all_documents[i:i + 10]
                 for doc in sub_documents:
-                    document_format_thread = threading.Thread(target=self.format_document, kwargs={
-                        'llm': llm, 'document_node': doc, 'split_documents': split_documents,
-                        'document_form': document_form})
+                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
+                        'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents})
                     threads.append(document_format_thread)
                     document_format_thread.start()
                 for thread in threads:
                     thread.join()
-
-            all_documents.extend(split_documents)
-
+            return all_qa_documents
         return all_documents
 
-    def format_document(self, llm: StreamableOpenAI, document_node, split_documents, document_form: str):
+    def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
-            return format_documents
-        if document_form == 'text_model':
-            # text model document
-            doc_id = str(uuid.uuid4())
-            hash = helper.generate_text_hash(document_node.page_content)
-
-            document_node.metadata['doc_id'] = doc_id
-            document_node.metadata['doc_hash'] = hash
+            return
+        try:
+            # qa model document
+            response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
+            document_qa_list = self.format_split_text(response)
+            qa_documents = []
+            for result in document_qa_list:
+                qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
+                doc_id = str(uuid.uuid4())
+                hash = helper.generate_text_hash(result['question'])
+                qa_document.metadata['answer'] = result['answer']
+                qa_document.metadata['doc_id'] = doc_id
+                qa_document.metadata['doc_hash'] = hash
+                qa_documents.append(qa_document)
+            format_documents.extend(qa_documents)
+        except Exception as e:
+            logging.error(str(e))
 
-            format_documents.append(document_node)
-        elif document_form == 'qa_model':
-            try:
-                # qa model document
-                response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content)
-                document_qa_list = self.format_split_text(response)
-                qa_documents = []
-                for result in document_qa_list:
-                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(result['question'])
-                    qa_document.metadata['answer'] = result['answer']
-                    qa_document.metadata['doc_id'] = doc_id
-                    qa_document.metadata['doc_hash'] = hash
-                    qa_documents.append(qa_document)
-                format_documents.extend(qa_documents)
-            except Exception as e:
-                logging.error(str(e))
-        split_documents.extend(format_documents)
+        all_qa_documents.extend(format_documents)
 
 
     def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,