Procházet zdrojové kódy

feat(file-upload): add support for optional file source parameter (#9554)

-LAN- před 6 měsíci
rodič
revize
8d8a8fe295

+ 6 - 2
api/controllers/console/datasets/file.py

@@ -2,7 +2,7 @@ import urllib.parse
 
 from flask import request
 from flask_login import current_user
-from flask_restful import Resource, marshal_with
+from flask_restful import Resource, marshal_with, reqparse
 
 import services
 from configs import dify_config
@@ -48,6 +48,10 @@ class FileApi(Resource):
         # get file from request
         file = request.files["file"]
 
+        parser = reqparse.RequestParser()
+        parser.add_argument("source", type=str, required=False, location="args")
+        source = parser.parse_args().get("source")
+
         # check file
         if "file" not in request.files:
             raise NoFileUploadedError()
@@ -55,7 +59,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file=file, user=current_user)
+            upload_file = FileService.upload_file(file=file, user=current_user, source=source)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:

+ 6 - 2
api/controllers/web/file.py

@@ -1,7 +1,7 @@
 import urllib.parse
 
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import marshal_with, reqparse
 
 import services
 from controllers.web import api
@@ -18,6 +18,10 @@ class FileApi(WebApiResource):
         # get file from request
         file = request.files["file"]
 
+        parser = reqparse.RequestParser()
+        parser.add_argument("source", type=str, required=False, location="args")
+        source = parser.parse_args().get("source")
+
         # check file
         if "file" not in request.files:
             raise NoFileUploadedError()
@@ -25,7 +29,7 @@ class FileApi(WebApiResource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file, end_user)
+            upload_file = FileService.upload_file(file=file, user=end_user, source=source)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:

+ 11 - 6
api/services/file_service.py

@@ -2,7 +2,7 @@ import datetime
 import hashlib
 import uuid
 from collections.abc import Generator
-from typing import Union
+from typing import Literal, Union
 
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
@@ -28,7 +28,9 @@ PREVIEW_WORDS_LIMIT = 3000
 
 class FileService:
     @staticmethod
-    def upload_file(file: FileStorage, user: Union[Account, EndUser]) -> UploadFile:
+    def upload_file(
+        file: FileStorage, user: Union[Account, EndUser], source: Literal["datasets"] | None = None
+    ) -> UploadFile:
         # get file name
         filename = file.filename
         if not filename:
@@ -36,11 +38,9 @@ class FileService:
         extension = filename.split(".")[-1]
         if len(filename) > 200:
             filename = filename.split(".")[0][:200] + "." + extension
-        # read file content
-        file_content = file.read()
 
-        # get file size
-        file_size = len(file_content)
+        if source == "datasets" and extension not in DOCUMENT_EXTENSIONS:
+            raise UnsupportedFileTypeError()
 
         # select file size limit
         if extension in IMAGE_EXTENSIONS:
@@ -52,6 +52,11 @@ class FileService:
         else:
             file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
 
+        # read file content
+        file_content = file.read()
+        # get file size
+        file_size = len(file_content)
+
         # check if the file size is exceeded
         if file_size > file_size_limit:
             message = f"File size exceeded. {file_size} > {file_size_limit}"