瀏覽代碼

feat: support legacy doc (#2100)

crazywoola 1 年之前
父節點
當前提交
1f48e3d44a

+ 2 - 6
api/controllers/console/datasets/file.py

@@ -9,7 +9,7 @@ from flask import current_app, request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
 from libs.login import login_required
-from services.file_service import FileService
+from services.file_service import FileService, ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS
 
 PREVIEW_WORDS_LIMIT = 3000
 
@@ -71,11 +71,7 @@ class FileSupportTypeApi(Resource):
     @account_initialization_required
     def get(self):
         etl_type = current_app.config['ETL_TYPE']
-        if etl_type == 'Unstructured':
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
-                                  'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml']
-        else:
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         return {'allowed_extensions': allowed_extensions}
 
 

+ 5 - 5
api/core/data_loader/file_extractor.py

@@ -27,7 +27,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 
 class FileExtractor:
     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -36,7 +36,7 @@ class FileExtractor:
             return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
 
     @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
         response = requests.get(url, headers={
             "User-Agent": USER_AGENT
         })
@@ -52,7 +52,7 @@ class FileExtractor:
     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
                        upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[List[Document] | str]:
+                       is_automatic: bool = False) -> Union[List[Document], str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
@@ -68,7 +68,7 @@ class FileExtractor:
                     else MarkdownLoader(file_path, autodetect_encoding=True)
             elif file_extension in ['.htm', '.html']:
                 loader = HTMLLoader(file_path)
-            elif file_extension == '.docx':
+            elif file_extension in ['.docx', '.doc']:
                 loader = Docx2txtLoader(file_path)
             elif file_extension == '.csv':
                 loader = CSVLoader(file_path, autodetect_encoding=True)
@@ -95,7 +95,7 @@ class FileExtractor:
                 loader = MarkdownLoader(file_path, autodetect_encoding=True)
             elif file_extension in ['.htm', '.html']:
                 loader = HTMLLoader(file_path)
-            elif file_extension == '.docx':
+            elif file_extension in ['.docx', '.doc']:
                 loader = Docx2txtLoader(file_path)
             elif file_extension == '.csv':
                 loader = CSVLoader(file_path, autodetect_encoding=True)

+ 1 - 3
api/core/data_loader/loader/unstructured/unstructured_msg.py

@@ -1,9 +1,7 @@
 import logging
-import re
-from typing import List, Optional, Tuple, cast
+from typing import List
 
 from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
 from langchain.schema import Document
 
 logger = logging.getLogger(__name__)

+ 1 - 5
api/core/data_loader/loader/unstructured/unstructured_ppt.py

@@ -1,14 +1,10 @@
 import logging
-import re
-from typing import List, Optional, Tuple, cast
-
+from typing import List
 from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
 from langchain.schema import Document
 
 logger = logging.getLogger(__name__)
 
-
 class UnstructuredPPTLoader(BaseLoader):
     """Load msg files.
 

+ 1 - 5
api/core/data_loader/loader/unstructured/unstructured_pptx.py

@@ -1,14 +1,10 @@
 import logging
-import re
-from typing import List, Optional, Tuple, cast
+from typing import List
 
 from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
 from langchain.schema import Document
 
 logger = logging.getLogger(__name__)
-
-
 class UnstructuredPPTXLoader(BaseLoader):
     """Load msg files.
 

+ 1 - 3
api/core/data_loader/loader/unstructured/unstructured_text.py

@@ -1,9 +1,7 @@
 import logging
-import re
-from typing import List, Optional, Tuple, cast
+from typing import List
 
 from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
 from langchain.schema import Document
 
 logger = logging.getLogger(__name__)

+ 1 - 3
api/core/data_loader/loader/unstructured/unstructured_xml.py

@@ -1,9 +1,7 @@
 import logging
-import re
-from typing import List, Optional, Tuple, cast
+from typing import List
 
 from langchain.document_loaders.base import BaseLoader
-from langchain.document_loaders.helpers import detect_file_encodings
 from langchain.schema import Document
 
 logger = logging.getLogger(__name__)

+ 5 - 16
api/services/file_service.py

@@ -15,9 +15,10 @@ from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
-                      'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
 IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'doc', 'csv'] + IMAGE_EXTENSIONS
+UNSTRUSTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
+                                      'docx', 'doc', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml'] + IMAGE_EXTENSIONS
 PREVIEW_WORDS_LIMIT = 3000
 
 
@@ -27,13 +28,7 @@ class FileService:
     def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
         extension = file.filename.split('.')[-1]
         etl_type = current_app.config['ETL_TYPE']
-        if etl_type == 'Unstructured':
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
-                                  'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml',
-                                  'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
-        else:
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
-                                  'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()
         elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
@@ -133,13 +128,7 @@ class FileService:
         # extract text from file
         extension = upload_file.extension
         etl_type = current_app.config['ETL_TYPE']
-        if etl_type == 'Unstructured':
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx',
-                                  'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml',
-                                  'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
-        else:
-            allowed_extensions = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
-                                  'jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
+        allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
         if extension.lower() not in allowed_extensions:
             raise UnsupportedFileTypeError()