file_manager.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import base64
  2. from configs import dify_config
  3. from core.file import file_repository
  4. from core.helper import ssrf_proxy
  5. from core.model_runtime.entities import (
  6. AudioPromptMessageContent,
  7. DocumentPromptMessageContent,
  8. ImagePromptMessageContent,
  9. VideoPromptMessageContent,
  10. )
  11. from extensions.ext_database import db
  12. from extensions.ext_storage import storage
  13. from . import helpers
  14. from .enums import FileAttribute
  15. from .models import File, FileTransferMethod, FileType
  16. from .tool_file_parser import ToolFileParser
  def get_attr(*, file: File, attr: FileAttribute):
      """
      Read a single attribute from a file.

      Args:
          file (File): The file to inspect.
          attr (FileAttribute): Selects which attribute of the file is returned.

      Returns:
          The value of the selected attribute. For TYPE and TRANSFER_METHOD this is
          the ``.value`` of the corresponding field. None is returned when ``attr``
          is not one of the attributes handled below.
      """
      if attr == FileAttribute.TYPE:
          return file.type.value
      if attr == FileAttribute.SIZE:
          return file.size
      if attr == FileAttribute.NAME:
          return file.filename
      if attr == FileAttribute.MIME_TYPE:
          return file.mime_type
      if attr == FileAttribute.TRANSFER_METHOD:
          return file.transfer_method.value
      if attr == FileAttribute.URL:
          return file.remote_url
      if attr == FileAttribute.EXTENSION:
          return file.extension
      # No handled attribute matched.
      return None
  33. def to_prompt_message_content(
  34. f: File,
  35. /,
  36. *,
  37. image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
  38. ):
  39. match f.type:
  40. case FileType.IMAGE:
  41. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  42. if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
  43. data = _to_url(f)
  44. else:
  45. data = _to_base64_data_string(f)
  46. return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
  47. case FileType.AUDIO:
  48. data = _to_base64_data_string(f)
  49. if f.extension is None:
  50. raise ValueError("Missing file extension")
  51. return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
  52. case FileType.VIDEO:
  53. if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
  54. data = _to_url(f)
  55. else:
  56. data = _to_base64_data_string(f)
  57. if f.extension is None:
  58. raise ValueError("Missing file extension")
  59. return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
  60. case FileType.DOCUMENT:
  61. data = _to_base64_data_string(f)
  62. return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
  63. case _:
  64. raise ValueError(f"file type {f.type} is not supported")
  def download(f: File, /):
      """
      Download and return the content of a file.

      Tool files and local files are looked up through file_repository and then
      loaded from storage by their key. Any other transfer method is treated as a
      remote file and fetched from ``f.remote_url`` through ssrf_proxy, following
      redirects.

      Args:
          f (File): The file to download.

      Returns:
          The file content: bytes loaded from storage, or the ``content`` of the
          remote response.

      Raises:
          ValueError: If the file has to be fetched remotely but has no remote_url,
              or if the stored file content is not a bytes object.
      """
      if f.transfer_method == FileTransferMethod.TOOL_FILE:
          tool_file = file_repository.get_tool_file(session=db.session(), file=f)
          return _download_file_content(tool_file.file_key)
      elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
          upload_file = file_repository.get_upload_file(session=db.session(), file=f)
          return _download_file_content(upload_file.key)
      # remote file
      if f.remote_url is None:
          # Fail with the same error _to_url uses instead of handing None to the HTTP client.
          raise ValueError("Missing file remote_url")
      response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
      response.raise_for_status()
      return response.content
  def _download_file_content(path: str, /):
      """
      Load a file from storage and return its raw bytes.

      The file is requested from storage in non-streaming mode, and the loaded
      object is verified to be bytes before it is handed back.

      Args:
          path (str): Location of the file in storage.

      Returns:
          bytes: The contents of the file.

      Raises:
          ValueError: If storage returns something other than a bytes object.
      """
      content = storage.load(path, stream=False)
      if isinstance(content, bytes):
          return content
      raise ValueError(f"file {path} is not a bytes object")
  91. def _get_encoded_string(f: File, /):
  92. match f.transfer_method:
  93. case FileTransferMethod.REMOTE_URL:
  94. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  95. response.raise_for_status()
  96. data = response.content
  97. case FileTransferMethod.LOCAL_FILE:
  98. upload_file = file_repository.get_upload_file(session=db.session(), file=f)
  99. data = _download_file_content(upload_file.key)
  100. case FileTransferMethod.TOOL_FILE:
  101. tool_file = file_repository.get_tool_file(session=db.session(), file=f)
  102. data = _download_file_content(tool_file.file_key)
  103. encoded_string = base64.b64encode(data).decode("utf-8")
  104. return encoded_string
  def _to_base64_data_string(f: File, /):
      """
      Build a base64 data string for a file.

      Args:
          f (File): The file to encode.

      Returns:
          str: A string of the form ``data:<mime_type>;base64,<encoded content>``.
      """
      payload = _get_encoded_string(f)
      header = f"data:{f.mime_type};base64,"
      return header + payload
  108. def _to_url(f: File, /):
  109. if f.transfer_method == FileTransferMethod.REMOTE_URL:
  110. if f.remote_url is None:
  111. raise ValueError("Missing file remote_url")
  112. return f.remote_url
  113. elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
  114. if f.related_id is None:
  115. raise ValueError("Missing file related_id")
  116. return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
  117. elif f.transfer_method == FileTransferMethod.TOOL_FILE:
  118. # add sign url
  119. if f.related_id is None or f.extension is None:
  120. raise ValueError("Missing file related_id or extension")
  121. return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
  122. else:
  123. raise ValueError(f"Unsupported transfer method: {f.transfer_method}")