# azure_blob_storage.py
  1. from collections.abc import Generator
  2. from datetime import UTC, datetime, timedelta
  3. from typing import Optional
  4. from azure.identity import ChainedTokenCredential, DefaultAzureCredential
  5. from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
  6. from configs import dify_config
  7. from extensions.ext_redis import redis_client
  8. from extensions.storage.base_storage import BaseStorage
  9. class AzureBlobStorage(BaseStorage):
  10. """Implementation for Azure Blob storage."""
  11. def __init__(self):
  12. super().__init__()
  13. self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME
  14. self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL
  15. self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME
  16. self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY
  17. self.credential: Optional[ChainedTokenCredential] = None
  18. if self.account_key == "managedidentity":
  19. self.credential = DefaultAzureCredential()
  20. else:
  21. self.credential = None
  22. def save(self, filename, data):
  23. client = self._sync_client()
  24. blob_container = client.get_container_client(container=self.bucket_name)
  25. blob_container.upload_blob(filename, data)
  26. def load_once(self, filename: str) -> bytes:
  27. client = self._sync_client()
  28. blob = client.get_container_client(container=self.bucket_name)
  29. blob = blob.get_blob_client(blob=filename)
  30. data: bytes = blob.download_blob().readall()
  31. return data
  32. def load_stream(self, filename: str) -> Generator:
  33. client = self._sync_client()
  34. blob = client.get_blob_client(container=self.bucket_name, blob=filename)
  35. blob_data = blob.download_blob()
  36. yield from blob_data.chunks()
  37. def download(self, filename, target_filepath):
  38. client = self._sync_client()
  39. blob = client.get_blob_client(container=self.bucket_name, blob=filename)
  40. with open(target_filepath, "wb") as my_blob:
  41. blob_data = blob.download_blob()
  42. blob_data.readinto(my_blob)
  43. def exists(self, filename):
  44. client = self._sync_client()
  45. blob = client.get_blob_client(container=self.bucket_name, blob=filename)
  46. return blob.exists()
  47. def delete(self, filename):
  48. client = self._sync_client()
  49. blob_container = client.get_container_client(container=self.bucket_name)
  50. blob_container.delete_blob(filename)
  51. def _sync_client(self):
  52. if self.account_key == "managedidentity":
  53. return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore
  54. cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key)
  55. cache_result = redis_client.get(cache_key)
  56. if cache_result is not None:
  57. sas_token = cache_result.decode("utf-8")
  58. else:
  59. sas_token = generate_account_sas(
  60. account_name=self.account_name or "",
  61. account_key=self.account_key or "",
  62. resource_types=ResourceTypes(service=True, container=True, object=True),
  63. permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
  64. expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
  65. )
  66. redis_client.set(cache_key, sas_token, ex=3000)
  67. return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)