瀏覽代碼

refactor: update load_stream method to directly yield file chunks; also fix SupabaseStorage.download to write to target_filepath (#9806)

zhuhao 6 月之前
父節點
當前提交
5bf31e7a86

+ 3 - 6
api/extensions/storage/aliyun_oss_storage.py

@@ -36,12 +36,9 @@ class AliyunOssStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            obj = self.client.get_object(self.__wrapper_folder_filename(filename))
-            while chunk := obj.read(4096):
-                yield chunk
-
-        return generate()
+        obj = self.client.get_object(self.__wrapper_folder_filename(filename))
+        while chunk := obj.read(4096):
+            yield chunk
 
     def download(self, filename, target_filepath):
         self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath)

+ 8 - 11
api/extensions/storage/aws_s3_storage.py

@@ -62,17 +62,14 @@ class AwsS3Storage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            try:
-                response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
-                yield from response["Body"].iter_chunks()
-            except ClientError as ex:
-                if ex.response["Error"]["Code"] == "NoSuchKey":
-                    raise FileNotFoundError("File not found")
-                else:
-                    raise
-
-        return generate()
+        try:
+            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
+            yield from response["Body"].iter_chunks()
+        except ClientError as ex:
+            if ex.response["Error"]["Code"] == "NoSuchKey":
+                raise FileNotFoundError("File not found")
+            else:
+                raise
 
     def download(self, filename, target_filepath):
         self.client.download_file(self.bucket_name, filename, target_filepath)

+ 3 - 7
api/extensions/storage/azure_blob_storage.py

@@ -32,13 +32,9 @@ class AzureBlobStorage(BaseStorage):
 
     def load_stream(self, filename: str) -> Generator:
         client = self._sync_client()
-
-        def generate(filename: str = filename) -> Generator:
-            blob = client.get_blob_client(container=self.bucket_name, blob=filename)
-            blob_data = blob.download_blob()
-            yield from blob_data.chunks()
-
-        return generate(filename)
+        blob = client.get_blob_client(container=self.bucket_name, blob=filename)
+        blob_data = blob.download_blob()
+        yield from blob_data.chunks()
 
     def download(self, filename, target_filepath):
         client = self._sync_client()

+ 3 - 6
api/extensions/storage/baidu_obs_storage.py

@@ -39,12 +39,9 @@ class BaiduObsStorage(BaseStorage):
         return response.data.read()
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
-            while chunk := response.read(4096):
-                yield chunk
-
-        return generate()
+        response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data
+        while chunk := response.read(4096):
+            yield chunk
 
     def download(self, filename, target_filepath):
         self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath)

+ 5 - 8
api/extensions/storage/google_cloud_storage.py

@@ -39,14 +39,11 @@ class GoogleCloudStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            bucket = self.client.get_bucket(self.bucket_name)
-            blob = bucket.get_blob(filename)
-            with blob.open(mode="rb") as blob_stream:
-                while chunk := blob_stream.read(4096):
-                    yield chunk
-
-        return generate()
+        bucket = self.client.get_bucket(self.bucket_name)
+        blob = bucket.get_blob(filename)
+        with blob.open(mode="rb") as blob_stream:
+            while chunk := blob_stream.read(4096):
+                yield chunk
 
     def download(self, filename, target_filepath):
         bucket = self.client.get_bucket(self.bucket_name)

+ 3 - 6
api/extensions/storage/huawei_obs_storage.py

@@ -27,12 +27,9 @@ class HuaweiObsStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
-            while chunk := response.read(4096):
-                yield chunk
-
-        return generate()
+        response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
+        while chunk := response.read(4096):
+            yield chunk
 
     def download(self, filename, target_filepath):
         self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath)

+ 5 - 9
api/extensions/storage/local_fs_storage.py

@@ -40,15 +40,11 @@ class LocalFsStorage(BaseStorage):
 
     def load_stream(self, filename: str) -> Generator:
         filepath = self._build_filepath(filename)
-
-        def generate() -> Generator:
-            if not os.path.exists(filepath):
-                raise FileNotFoundError("File not found")
-            with open(filepath, "rb") as f:
-                while chunk := f.read(4096):  # Read in chunks of 4KB
-                    yield chunk
-
-        return generate()
+        if not os.path.exists(filepath):
+            raise FileNotFoundError("File not found")
+        with open(filepath, "rb") as f:
+            while chunk := f.read(4096):  # Read in chunks of 4KB
+                yield chunk
 
     def download(self, filename, target_filepath):
         filepath = self._build_filepath(filename)

+ 8 - 11
api/extensions/storage/oracle_oci_storage.py

@@ -36,17 +36,14 @@ class OracleOCIStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            try:
-                response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
-                yield from response["Body"].iter_chunks()
-            except ClientError as ex:
-                if ex.response["Error"]["Code"] == "NoSuchKey":
-                    raise FileNotFoundError("File not found")
-                else:
-                    raise
-
-        return generate()
+        try:
+            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
+            yield from response["Body"].iter_chunks()
+        except ClientError as ex:
+            if ex.response["Error"]["Code"] == "NoSuchKey":
+                raise FileNotFoundError("File not found")
+            else:
+                raise
 
     def download(self, filename, target_filepath):
         self.client.download_file(self.bucket_name, filename, target_filepath)

+ 5 - 8
api/extensions/storage/supabase_storage.py

@@ -36,17 +36,14 @@ class SupabaseStorage(BaseStorage):
         return content
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            result = self.client.storage.from_(self.bucket_name).download(filename)
-            byte_stream = io.BytesIO(result)
-            while chunk := byte_stream.read(4096):  # Read in chunks of 4KB
-                yield chunk
-
-        return generate()
+        result = self.client.storage.from_(self.bucket_name).download(filename)
+        byte_stream = io.BytesIO(result)
+        while chunk := byte_stream.read(4096):  # Read in chunks of 4KB
+            yield chunk
 
     def download(self, filename, target_filepath):
         result = self.client.storage.from_(self.bucket_name).download(filename)
-        Path(result).write_bytes(result)
+        Path(target_filepath).write_bytes(result)
 
     def exists(self, filename):
         result = self.client.storage.from_(self.bucket_name).list(filename)

+ 2 - 5
api/extensions/storage/tencent_cos_storage.py

@@ -29,11 +29,8 @@ class TencentCosStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
-            yield from response["Body"].get_stream(chunk_size=4096)
-
-        return generate()
+        response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
+        yield from response["Body"].get_stream(chunk_size=4096)
 
     def download(self, filename, target_filepath):
         response = self.client.get_object(Bucket=self.bucket_name, Key=filename)

+ 3 - 6
api/extensions/storage/volcengine_tos_storage.py

@@ -27,12 +27,9 @@ class VolcengineTosStorage(BaseStorage):
         return data
 
     def load_stream(self, filename: str) -> Generator:
-        def generate(filename: str = filename) -> Generator:
-            response = self.client.get_object(bucket=self.bucket_name, key=filename)
-            while chunk := response.read(4096):
-                yield chunk
-
-        return generate()
+        response = self.client.get_object(bucket=self.bucket_name, key=filename)
+        while chunk := response.read(4096):
+            yield chunk
 
     def download(self, filename, target_filepath):
         self.client.get_object_to_file(bucket=self.bucket_name, key=filename, file_path=target_filepath)