configuration.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from copy import deepcopy
  2. from typing import Any
  3. from pydantic import BaseModel
  4. from core.helper import encrypter
  5. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  6. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  7. from core.tools.entities.tool_entities import (
  8. ToolParameter,
  9. ToolProviderCredentials,
  10. )
  11. from core.tools.provider.tool_provider import ToolProviderController
  12. from core.tools.tool.tool import Tool
  13. class ToolConfigurationManager(BaseModel):
  14. tenant_id: str
  15. provider_controller: ToolProviderController
  16. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  17. """
  18. deep copy credentials
  19. """
  20. return deepcopy(credentials)
  21. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  22. """
  23. encrypt tool credentials with tenant id
  24. return a deep copy of credentials with encrypted values
  25. """
  26. credentials = self._deep_copy(credentials)
  27. # get fields need to be decrypted
  28. fields = self.provider_controller.get_credentials_schema()
  29. for field_name, field in fields.items():
  30. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  31. if field_name in credentials:
  32. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  33. credentials[field_name] = encrypted
  34. return credentials
  35. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  36. """
  37. mask tool credentials
  38. return a deep copy of credentials with masked values
  39. """
  40. credentials = self._deep_copy(credentials)
  41. # get fields need to be decrypted
  42. fields = self.provider_controller.get_credentials_schema()
  43. for field_name, field in fields.items():
  44. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  45. if field_name in credentials:
  46. if len(credentials[field_name]) > 6:
  47. credentials[field_name] = (
  48. credentials[field_name][:2]
  49. + "*" * (len(credentials[field_name]) - 4)
  50. + credentials[field_name][-2:]
  51. )
  52. else:
  53. credentials[field_name] = "*" * len(credentials[field_name])
  54. return credentials
  55. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  56. """
  57. decrypt tool credentials with tenant id
  58. return a deep copy of credentials with decrypted values
  59. """
  60. identity_id = ""
  61. if self.provider_controller.identity:
  62. identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
  63. cache = ToolProviderCredentialsCache(
  64. tenant_id=self.tenant_id,
  65. identity_id=identity_id,
  66. cache_type=ToolProviderCredentialsCacheType.PROVIDER,
  67. )
  68. cached_credentials = cache.get()
  69. if cached_credentials:
  70. return cached_credentials
  71. credentials = self._deep_copy(credentials)
  72. # get fields need to be decrypted
  73. fields = self.provider_controller.get_credentials_schema()
  74. for field_name, field in fields.items():
  75. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  76. if field_name in credentials:
  77. try:
  78. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  79. except:
  80. pass
  81. cache.set(credentials)
  82. return credentials
  83. def delete_tool_credentials_cache(self):
  84. identity_id = ""
  85. if self.provider_controller.identity:
  86. identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
  87. cache = ToolProviderCredentialsCache(
  88. tenant_id=self.tenant_id,
  89. identity_id=identity_id,
  90. cache_type=ToolProviderCredentialsCacheType.PROVIDER,
  91. )
  92. cache.delete()
  93. class ToolParameterConfigurationManager(BaseModel):
  94. """
  95. Tool parameter configuration manager
  96. """
  97. tenant_id: str
  98. tool_runtime: Tool
  99. provider_name: str
  100. provider_type: str
  101. identity_id: str
  102. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  103. """
  104. deep copy parameters
  105. """
  106. return deepcopy(parameters)
  107. def _merge_parameters(self) -> list[ToolParameter]:
  108. """
  109. merge parameters
  110. """
  111. # get tool parameters
  112. tool_parameters = self.tool_runtime.parameters or []
  113. # get tool runtime parameters
  114. runtime_parameters = self.tool_runtime.get_runtime_parameters()
  115. # override parameters
  116. current_parameters = tool_parameters.copy()
  117. for runtime_parameter in runtime_parameters:
  118. found = False
  119. for index, parameter in enumerate(current_parameters):
  120. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  121. current_parameters[index] = runtime_parameter
  122. found = True
  123. break
  124. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  125. current_parameters.append(runtime_parameter)
  126. return current_parameters
  127. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  128. """
  129. mask tool parameters
  130. return a deep copy of parameters with masked values
  131. """
  132. parameters = self._deep_copy(parameters)
  133. # override parameters
  134. current_parameters = self._merge_parameters()
  135. for parameter in current_parameters:
  136. if (
  137. parameter.form == ToolParameter.ToolParameterForm.FORM
  138. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  139. ):
  140. if parameter.name in parameters:
  141. if len(parameters[parameter.name]) > 6:
  142. parameters[parameter.name] = (
  143. parameters[parameter.name][:2]
  144. + "*" * (len(parameters[parameter.name]) - 4)
  145. + parameters[parameter.name][-2:]
  146. )
  147. else:
  148. parameters[parameter.name] = "*" * len(parameters[parameter.name])
  149. return parameters
  150. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  151. """
  152. encrypt tool parameters with tenant id
  153. return a deep copy of parameters with encrypted values
  154. """
  155. # override parameters
  156. current_parameters = self._merge_parameters()
  157. parameters = self._deep_copy(parameters)
  158. for parameter in current_parameters:
  159. if (
  160. parameter.form == ToolParameter.ToolParameterForm.FORM
  161. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  162. ):
  163. if parameter.name in parameters:
  164. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  165. parameters[parameter.name] = encrypted
  166. return parameters
  167. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  168. """
  169. decrypt tool parameters with tenant id
  170. return a deep copy of parameters with decrypted values
  171. """
  172. if self.tool_runtime is None or self.tool_runtime.identity is None:
  173. raise ValueError("tool_runtime is required")
  174. cache = ToolParameterCache(
  175. tenant_id=self.tenant_id,
  176. provider=f"{self.provider_type}.{self.provider_name}",
  177. tool_name=self.tool_runtime.identity.name,
  178. cache_type=ToolParameterCacheType.PARAMETER,
  179. identity_id=self.identity_id,
  180. )
  181. cached_parameters = cache.get()
  182. if cached_parameters:
  183. return cached_parameters
  184. # override parameters
  185. current_parameters = self._merge_parameters()
  186. has_secret_input = False
  187. for parameter in current_parameters:
  188. if (
  189. parameter.form == ToolParameter.ToolParameterForm.FORM
  190. and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
  191. ):
  192. if parameter.name in parameters:
  193. try:
  194. has_secret_input = True
  195. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  196. except:
  197. pass
  198. if has_secret_input:
  199. cache.set(parameters)
  200. return parameters
  201. def delete_tool_parameters_cache(self):
  202. if self.tool_runtime is None or self.tool_runtime.identity is None:
  203. raise ValueError("tool_runtime is required")
  204. cache = ToolParameterCache(
  205. tenant_id=self.tenant_id,
  206. provider=f"{self.provider_type}.{self.provider_name}",
  207. tool_name=self.tool_runtime.identity.name,
  208. cache_type=ToolParameterCacheType.PARAMETER,
  209. identity_id=self.identity_id,
  210. )
  211. cache.delete()