tool_file_manager.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import logging
  2. import time
  3. import os
  4. import hmac
  5. import base64
  6. import hashlib
  7. from typing import Union, Tuple, Generator
  8. from uuid import uuid4
  9. from mimetypes import guess_extension, guess_type
  10. from httpx import get
  11. from flask import current_app
  12. from models.tools import ToolFile
  13. from models.model import MessageFile
  14. from extensions.ext_database import db
  15. from extensions.ext_storage import storage
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
  17. class ToolFileManager:
  18. @staticmethod
  19. def sign_file(file_id: str, extension: str) -> str:
  20. """
  21. sign file to get a temporary url
  22. """
  23. base_url = current_app.config.get('FILES_URL')
  24. file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'
  25. timestamp = str(int(time.time()))
  26. nonce = os.urandom(16).hex()
  27. data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
  28. secret_key = current_app.config['SECRET_KEY'].encode()
  29. sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
  30. encoded_sign = base64.urlsafe_b64encode(sign).decode()
  31. return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
  32. @staticmethod
  33. def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
  34. """
  35. verify signature
  36. """
  37. data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
  38. secret_key = current_app.config['SECRET_KEY'].encode()
  39. recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
  40. recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
  41. # verify signature
  42. if sign != recalculated_encoded_sign:
  43. return False
  44. current_time = int(time.time())
  45. return current_time - int(timestamp) <= 300 # expired after 5 minutes
  46. @staticmethod
  47. def create_file_by_raw(user_id: str, tenant_id: str,
  48. conversation_id: str, file_binary: bytes,
  49. mimetype: str
  50. ) -> ToolFile:
  51. """
  52. create file
  53. """
  54. extension = guess_extension(mimetype) or '.bin'
  55. unique_name = uuid4().hex
  56. filename = f"/tools/{tenant_id}/{unique_name}{extension}"
  57. storage.save(filename, file_binary)
  58. tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
  59. conversation_id=conversation_id, file_key=filename, mimetype=mimetype)
  60. db.session.add(tool_file)
  61. db.session.commit()
  62. return tool_file
  63. @staticmethod
  64. def create_file_by_url(user_id: str, tenant_id: str,
  65. conversation_id: str, file_url: str,
  66. ) -> ToolFile:
  67. """
  68. create file
  69. """
  70. # try to download image
  71. response = get(file_url)
  72. response.raise_for_status()
  73. blob = response.content
  74. mimetype = guess_type(file_url)[0] or 'octet/stream'
  75. extension = guess_extension(mimetype) or '.bin'
  76. unique_name = uuid4().hex
  77. filename = f"/tools/{tenant_id}/{unique_name}{extension}"
  78. storage.save(filename, blob)
  79. tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
  80. conversation_id=conversation_id, file_key=filename,
  81. mimetype=mimetype, original_url=file_url)
  82. db.session.add(tool_file)
  83. db.session.commit()
  84. return tool_file
  85. @staticmethod
  86. def create_file_by_key(user_id: str, tenant_id: str,
  87. conversation_id: str, file_key: str,
  88. mimetype: str
  89. ) -> ToolFile:
  90. """
  91. create file
  92. """
  93. tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
  94. conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
  95. return tool_file
  96. @staticmethod
  97. def get_file_binary(id: str) -> Union[Tuple[bytes, str], None]:
  98. """
  99. get file binary
  100. :param id: the id of the file
  101. :return: the binary of the file, mime type
  102. """
  103. tool_file: ToolFile = db.session.query(ToolFile).filter(
  104. ToolFile.id == id,
  105. ).first()
  106. if not tool_file:
  107. return None
  108. blob = storage.load_once(tool_file.file_key)
  109. return blob, tool_file.mimetype
  110. @staticmethod
  111. def get_file_binary_by_message_file_id(id: str) -> Union[Tuple[bytes, str], None]:
  112. """
  113. get file binary
  114. :param id: the id of the file
  115. :return: the binary of the file, mime type
  116. """
  117. message_file: MessageFile = db.session.query(MessageFile).filter(
  118. MessageFile.id == id,
  119. ).first()
  120. # get tool file id
  121. tool_file_id = message_file.url.split('/')[-1]
  122. # trim extension
  123. tool_file_id = tool_file_id.split('.')[0]
  124. tool_file: ToolFile = db.session.query(ToolFile).filter(
  125. ToolFile.id == tool_file_id,
  126. ).first()
  127. if not tool_file:
  128. return None
  129. blob = storage.load_once(tool_file.file_key)
  130. return blob, tool_file.mimetype
  131. @staticmethod
  132. def get_file_generator_by_message_file_id(id: str) -> Union[Tuple[Generator, str], None]:
  133. """
  134. get file binary
  135. :param id: the id of the file
  136. :return: the binary of the file, mime type
  137. """
  138. message_file: MessageFile = db.session.query(MessageFile).filter(
  139. MessageFile.id == id,
  140. ).first()
  141. # get tool file id
  142. tool_file_id = message_file.url.split('/')[-1]
  143. # trim extension
  144. tool_file_id = tool_file_id.split('.')[0]
  145. tool_file: ToolFile = db.session.query(ToolFile).filter(
  146. ToolFile.id == tool_file_id,
  147. ).first()
  148. if not tool_file:
  149. return None
  150. generator = storage.load_stream(tool_file.file_key)
  151. return generator, tool_file.mimetype
# init tool_file_parser
# NOTE(review): this import sits below the class definition instead of with
# the top-of-file imports -- presumably to avoid a circular import; confirm
# before moving it.
from core.file.tool_file_parser import tool_file_manager
# Store the ToolFileManager class itself (not an instance) under the
# 'manager' key at import time.
tool_file_manager['manager'] = ToolFileManager