configuration.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import os
  2. from copy import deepcopy
  3. from typing import Any, Union
  4. from pydantic import BaseModel
  5. from yaml import FullLoader, load
  6. from core.helper import encrypter
  7. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  8. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  9. from core.tools.entities.tool_entities import (
  10. ModelToolConfiguration,
  11. ModelToolProviderConfiguration,
  12. ToolParameter,
  13. ToolProviderCredentials,
  14. )
  15. from core.tools.provider.tool_provider import ToolProviderController
  16. from core.tools.tool.tool import Tool
  17. class ToolConfigurationManager(BaseModel):
  18. tenant_id: str
  19. provider_controller: ToolProviderController
  20. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  21. """
  22. deep copy credentials
  23. """
  24. return deepcopy(credentials)
  25. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  26. """
  27. encrypt tool credentials with tenant id
  28. return a deep copy of credentials with encrypted values
  29. """
  30. credentials = self._deep_copy(credentials)
  31. # get fields need to be decrypted
  32. fields = self.provider_controller.get_credentials_schema()
  33. for field_name, field in fields.items():
  34. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  35. if field_name in credentials:
  36. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  37. credentials[field_name] = encrypted
  38. return credentials
  39. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  40. """
  41. mask tool credentials
  42. return a deep copy of credentials with masked values
  43. """
  44. credentials = self._deep_copy(credentials)
  45. # get fields need to be decrypted
  46. fields = self.provider_controller.get_credentials_schema()
  47. for field_name, field in fields.items():
  48. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  49. if field_name in credentials:
  50. if len(credentials[field_name]) > 6:
  51. credentials[field_name] = \
  52. credentials[field_name][:2] + \
  53. '*' * (len(credentials[field_name]) - 4) +\
  54. credentials[field_name][-2:]
  55. else:
  56. credentials[field_name] = '*' * len(credentials[field_name])
  57. return credentials
  58. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  59. """
  60. decrypt tool credentials with tenant id
  61. return a deep copy of credentials with decrypted values
  62. """
  63. cache = ToolProviderCredentialsCache(
  64. tenant_id=self.tenant_id,
  65. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  66. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  67. )
  68. cached_credentials = cache.get()
  69. if cached_credentials:
  70. return cached_credentials
  71. credentials = self._deep_copy(credentials)
  72. # get fields need to be decrypted
  73. fields = self.provider_controller.get_credentials_schema()
  74. for field_name, field in fields.items():
  75. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  76. if field_name in credentials:
  77. try:
  78. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  79. except:
  80. pass
  81. cache.set(credentials)
  82. return credentials
  83. def delete_tool_credentials_cache(self):
  84. cache = ToolProviderCredentialsCache(
  85. tenant_id=self.tenant_id,
  86. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  87. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  88. )
  89. cache.delete()
  90. class ToolParameterConfigurationManager(BaseModel):
  91. """
  92. Tool parameter configuration manager
  93. """
  94. tenant_id: str
  95. tool_runtime: Tool
  96. provider_name: str
  97. provider_type: str
  98. identity_id: str
  99. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  100. """
  101. deep copy parameters
  102. """
  103. return deepcopy(parameters)
  104. def _merge_parameters(self) -> list[ToolParameter]:
  105. """
  106. merge parameters
  107. """
  108. # get tool parameters
  109. tool_parameters = self.tool_runtime.parameters or []
  110. # get tool runtime parameters
  111. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  112. # override parameters
  113. current_parameters = tool_parameters.copy()
  114. for runtime_parameter in runtime_parameters:
  115. found = False
  116. for index, parameter in enumerate(current_parameters):
  117. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  118. current_parameters[index] = runtime_parameter
  119. found = True
  120. break
  121. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  122. current_parameters.append(runtime_parameter)
  123. return current_parameters
  124. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  125. """
  126. mask tool parameters
  127. return a deep copy of parameters with masked values
  128. """
  129. parameters = self._deep_copy(parameters)
  130. # override parameters
  131. current_parameters = self._merge_parameters()
  132. for parameter in current_parameters:
  133. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  134. if parameter.name in parameters:
  135. if len(parameters[parameter.name]) > 6:
  136. parameters[parameter.name] = \
  137. parameters[parameter.name][:2] + \
  138. '*' * (len(parameters[parameter.name]) - 4) +\
  139. parameters[parameter.name][-2:]
  140. else:
  141. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  142. return parameters
  143. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  144. """
  145. encrypt tool parameters with tenant id
  146. return a deep copy of parameters with encrypted values
  147. """
  148. # override parameters
  149. current_parameters = self._merge_parameters()
  150. parameters = self._deep_copy(parameters)
  151. for parameter in current_parameters:
  152. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  153. if parameter.name in parameters:
  154. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  155. parameters[parameter.name] = encrypted
  156. return parameters
  157. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  158. """
  159. decrypt tool parameters with tenant id
  160. return a deep copy of parameters with decrypted values
  161. """
  162. cache = ToolParameterCache(
  163. tenant_id=self.tenant_id,
  164. provider=f'{self.provider_type}.{self.provider_name}',
  165. tool_name=self.tool_runtime.identity.name,
  166. cache_type=ToolParameterCacheType.PARAMETER,
  167. identity_id=self.identity_id
  168. )
  169. cached_parameters = cache.get()
  170. if cached_parameters:
  171. return cached_parameters
  172. # override parameters
  173. current_parameters = self._merge_parameters()
  174. has_secret_input = False
  175. for parameter in current_parameters:
  176. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  177. if parameter.name in parameters:
  178. try:
  179. has_secret_input = True
  180. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  181. except:
  182. pass
  183. if has_secret_input:
  184. cache.set(parameters)
  185. return parameters
  186. def delete_tool_parameters_cache(self):
  187. cache = ToolParameterCache(
  188. tenant_id=self.tenant_id,
  189. provider=f'{self.provider_type}.{self.provider_name}',
  190. tool_name=self.tool_runtime.identity.name,
  191. cache_type=ToolParameterCacheType.PARAMETER,
  192. identity_id=self.identity_id
  193. )
  194. cache.delete()
  195. class ModelToolConfigurationManager:
  196. """
  197. Model as tool configuration
  198. """
  199. _configurations: dict[str, ModelToolProviderConfiguration] = {}
  200. _model_configurations: dict[str, ModelToolConfiguration] = {}
  201. _inited = False
  202. @classmethod
  203. def _init_configuration(cls):
  204. """
  205. init configuration
  206. """
  207. absolute_path = os.path.abspath(os.path.dirname(__file__))
  208. model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
  209. # get all .yaml file
  210. files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
  211. for file in files:
  212. provider = file.split('.')[0]
  213. with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
  214. configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
  215. models = configurations.models or []
  216. for model in models:
  217. model_key = f'{provider}.{model.model}'
  218. cls._model_configurations[model_key] = model
  219. cls._configurations[provider] = configurations
  220. cls._inited = True
  221. @classmethod
  222. def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
  223. """
  224. get configuration by provider
  225. """
  226. if not cls._inited:
  227. cls._init_configuration()
  228. return cls._configurations.get(provider, None)
  229. @classmethod
  230. def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
  231. """
  232. get all configurations
  233. """
  234. if not cls._inited:
  235. cls._init_configuration()
  236. return cls._configurations
  237. @classmethod
  238. def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
  239. """
  240. get model configuration
  241. """
  242. key = f'{provider}.{model}'
  243. if not cls._inited:
  244. cls._init_configuration()
  245. return cls._model_configurations.get(key, None)