Browse Source

support images and tables extract from docx (#4619)

Jyong 11 months ago
parent
commit
233c4150d1

+ 2 - 2
api/core/indexing_runner.py

@@ -428,7 +428,7 @@ class IndexingRunner:
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
         else:
@@ -436,7 +436,7 @@ class IndexingRunner:
             character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
 

+ 2 - 3
api/core/rag/extractor/extract_processor.py

@@ -16,7 +16,6 @@ from core.rag.extractor.markdown_extractor import MarkdownExtractor
 from core.rag.extractor.notion_extractor import NotionExtractor
 from core.rag.extractor.pdf_extractor import PdfExtractor
 from core.rag.extractor.text_extractor import TextExtractor
-from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
 from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor
 from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor
 from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
@@ -108,7 +107,7 @@ class ExtractProcessor:
                     elif file_extension in ['.htm', '.html']:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension in ['.docx']:
-                        extractor = UnstructuredWordExtractor(file_path, unstructured_api_url)
+                        extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension == '.csv':
                         extractor = CSVExtractor(file_path, autodetect_encoding=True)
                     elif file_extension == '.msg':
@@ -137,7 +136,7 @@ class ExtractProcessor:
                     elif file_extension in ['.htm', '.html']:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension in ['.docx']:
-                        extractor = WordExtractor(file_path)
+                        extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                     elif file_extension == '.csv':
                         extractor = CSVExtractor(file_path, autodetect_encoding=True)
                     elif file_extension == 'epub':

+ 121 - 7
api/core/rag/extractor/word_extractor.py

@@ -1,12 +1,20 @@
 """Abstract interface for document loader implementations."""
+import datetime
+import mimetypes
 import os
 import tempfile
+import uuid
 from urllib.parse import urlparse
 
 import requests
+from docx import Document as DocxDocument
+from flask import current_app
 
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import UploadFile
 
 
 class WordExtractor(BaseExtractor):
@@ -17,9 +25,12 @@ class WordExtractor(BaseExtractor):
         file_path: Path to the file to load.
     """
 
-    def __init__(self, file_path: str):
+    def __init__(self, file_path: str, tenant_id: str, user_id: str):
         """Initialize with file path."""
         self.file_path = file_path
+        self.tenant_id = tenant_id
+        self.user_id = user_id
+
         if "~" in self.file_path:
             self.file_path = os.path.expanduser(self.file_path)
 
@@ -45,12 +56,7 @@ class WordExtractor(BaseExtractor):
 
     def extract(self) -> list[Document]:
         """Load given path as single page."""
-        from docx import Document as docx_Document
-
-        document = docx_Document(self.file_path)
-        doc_texts = [paragraph.text for paragraph in document.paragraphs]
-        content = '\n'.join(doc_texts)
-
+        content = self.parse_docx(self.file_path, 'storage')
         return [Document(
             page_content=content,
             metadata={"source": self.file_path},
@@ -61,3 +67,111 @@ class WordExtractor(BaseExtractor):
         """Check if the url is valid."""
         parsed = urlparse(url)
         return bool(parsed.netloc) and bool(parsed.scheme)
+
+    def _extract_images_from_docx(self, doc, image_folder):
+        os.makedirs(image_folder, exist_ok=True)
+        image_count = 0
+        image_map = {}
+
+        for rel in doc.part.rels.values():
+            if "image" in rel.target_ref:
+                image_count += 1
+                image_ext = rel.target_ref.split('.')[-1]
+                # user uuid as file name
+                file_uuid = str(uuid.uuid4())
+                file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext
+                mime_type, _ = mimetypes.guess_type(file_key)
+
+                storage.save(file_key, rel.target_part.blob)
+                # save file to db
+                config = current_app.config
+                upload_file = UploadFile(
+                    tenant_id=self.tenant_id,
+                    storage_type=config['STORAGE_TYPE'],
+                    key=file_key,
+                    name=file_key,
+                    size=0,
+                    extension=image_ext,
+                    mime_type=mime_type,
+                    created_by=self.user_id,
+                    created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                    used=True,
+                    used_by=self.user_id,
+                    used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                )
+
+                db.session.add(upload_file)
+                db.session.commit()
+                image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+
+        return image_map
+
+    def _table_to_markdown(self, table):
+        markdown = ""
+        # deal with table headers
+        header_row = table.rows[0]
+        headers = [cell.text for cell in header_row.cells]
+        markdown += "| " + " | ".join(headers) + " |\n"
+        markdown += "| " + " | ".join(["---"] * len(headers)) + " |\n"
+        # deal with table rows
+        for row in table.rows[1:]:
+            row_cells = [cell.text for cell in row.cells]
+            markdown += "| " + " | ".join(row_cells) + " |\n"
+
+        return markdown
+
+    def _parse_paragraph(self, paragraph, image_map):
+        paragraph_content = []
+        for run in paragraph.runs:
+            if run.element.xpath('.//a:blip'):
+                for blip in run.element.xpath('.//a:blip'):
+                    embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+                    if embed_id:
+                        rel_target = run.part.rels[embed_id].target_ref
+                        if rel_target in image_map:
+                            paragraph_content.append(image_map[rel_target])
+            if run.text.strip():
+                paragraph_content.append(run.text.strip())
+        return ' '.join(paragraph_content) if paragraph_content else ''
+
+    def parse_docx(self, docx_path, image_folder):
+        doc = DocxDocument(docx_path)
+        os.makedirs(image_folder, exist_ok=True)
+
+        content = []
+
+        image_map = self._extract_images_from_docx(doc, image_folder)
+
+        def parse_paragraph(paragraph):
+            paragraph_content = []
+            for run in paragraph.runs:
+                if run.element.tag.endswith('r'):
+                    drawing_elements = run.element.findall(
+                        './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing')
+                    for drawing in drawing_elements:
+                        blip_elements = drawing.findall(
+                            './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip')
+                        for blip in blip_elements:
+                            embed_id = blip.get(
+                                '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
+                            if embed_id:
+                                image_part = doc.part.related_parts.get(embed_id)
+                                if image_part in image_map:
+                                    paragraph_content.append(image_map[image_part])
+                if run.text.strip():
+                    paragraph_content.append(run.text.strip())
+            return ''.join(paragraph_content) if paragraph_content else ''
+
+        paragraphs = doc.paragraphs.copy()
+        tables = doc.tables.copy()
+        for element in doc.element.body:
+            if element.tag.endswith('p'):  # paragraph
+                para = paragraphs.pop(0)
+                parsed_paragraph = parse_paragraph(para)
+                if parsed_paragraph:
+                    content.append(parsed_paragraph)
+            elif element.tag.endswith('tbl'):  # table
+                table = tables.pop(0)
+                content.append(self._table_to_markdown(table))
+        return '\n'.join(content)
+

+ 2 - 2
api/core/rag/index_processor/index_processor_base.py

@@ -57,7 +57,7 @@ class BaseIndexProcessor(ABC):
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=segmentation.get('chunk_overlap', 0),
                 fixed_separator=separator,
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
         else:
@@ -65,7 +65,7 @@ class BaseIndexProcessor(ABC):
             character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
-                separators=["\n\n", "。", ".", " ", ""],
+                separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance
             )
 

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -144,9 +144,9 @@ class DatasetRetrieval:
                                                                                        float('inf')))
             for segment in sorted_segments:
                 if segment.answer:
-                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                    document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                 else:
-                    document_context_list.append(segment.content)
+                    document_context_list.append(segment.get_sign_content())
             if show_retrieve_source:
                 context_list = []
                 resource_number = 1

+ 1 - 1
api/core/splitter/text_splitter.py

@@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 documents.append(new_doc)
         return documents
 
-    def split_documents(self, documents: Iterable[Document] ) -> list[Document]:
+    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
         """Split documents."""
         texts, metadatas = [], []
         for doc in documents:

+ 2 - 2
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -99,9 +99,9 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                                                                                        float('inf')))
             for segment in sorted_segments:
                 if segment.answer:
-                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                    document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                 else:
-                    document_context_list.append(segment.content)
+                    document_context_list.append(segment.get_sign_content())
             if self.return_resource:
                 context_list = []
                 resource_number = 1

+ 2 - 2
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -105,9 +105,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                                                                                            float('inf')))
                 for segment in sorted_segments:
                     if segment.answer:
-                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                        document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
                     else:
-                        document_context_list.append(segment.content)
+                        document_context_list.append(segment.get_sign_content())
                 if self.return_resource:
                     context_list = []
                     resource_number = 1

+ 2 - 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -191,9 +191,9 @@ class KnowledgeRetrievalNode(BaseNode):
                             'title': document.name
                         }
                         if segment.answer:
-                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                            source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}'
                         else:
-                            source['content'] = segment.content
+                            source['content'] = segment.get_sign_content()
                         context_list.append(source)
                         resource_number += 1
         return context_list

+ 27 - 0
api/models/dataset.py

@@ -1,8 +1,15 @@
+import base64
+import hashlib
+import hmac
 import json
 import logging
+import os
 import pickle
+import re
+import time
 from json import JSONDecodeError
 
+from flask import current_app
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
 
@@ -414,6 +421,26 @@ class DocumentSegment(db.Model):
             DocumentSegment.position == self.position + 1
         ).first()
 
+    def get_sign_content(self):
+        pattern = r"/files/([a-f0-9\-]+)/image-preview"
+        text = self.content
+        match = re.search(pattern, text)
+
+        if match:
+            upload_file_id = match.group(1)
+            nonce = os.urandom(16).hex()
+            timestamp = str(int(time.time()))
+            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+            secret_key = current_app.config['SECRET_KEY'].encode()
+            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+            encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+            replacement = r"\g<0>?{params}".format(params=params)
+            text = re.sub(pattern, replacement, text)
+        return text
+
+
 
 class AppDatasetJoin(db.Model):
     __tablename__ = 'app_dataset_joins'