tool_manager.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. from typing import List, Dict, Any, Tuple, Union
  2. from os import listdir, path
  3. from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials
  4. from core.tools.provider.tool_provider import ToolProviderController
  5. from core.tools.tool.builtin_tool import BuiltinTool
  6. from core.tools.tool.api_tool import ApiTool
  7. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  8. from core.tools.entities.constant import DEFAULT_PROVIDERS
  9. from core.tools.entities.common_entities import I18nObject
  10. from core.tools.errors import ToolProviderNotFoundError
  11. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  12. from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
  13. from core.tools.entities.user_entities import UserToolProvider
  14. from core.tools.utils.configration import ToolConfiguration
  15. from core.tools.utils.encoder import serialize_base_model_dict
  16. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  17. from core.model_runtime.entities.message_entities import PromptMessage
  18. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  19. from extensions.ext_database import db
  20. from models.tools import ApiToolProvider, BuiltinToolProvider
  21. import importlib
  22. import logging
  23. import json
  24. import mimetypes
  logger = logging.getLogger(name=__name__)
  # Cache of builtin provider controllers keyed by their ``identity.name``; it stays empty
  # until ToolManager.list_builtin_providers() loads and stores them.
  _builtin_providers: Dict[str, Any] = dict()
  27. class ToolManager:
  @staticmethod
  def invoke(
      provider: str,
      tool_id: str,
      tool_name: str,
      tool_parameters: Dict[str, Any],
      credentials: Dict[str, Any],
      prompt_messages: List[PromptMessage],
  ) -> List[ToolInvokeMessage]:
      """
      invoke a tool of the given provider

      ``DEFAULT_PROVIDERS.API_BASED`` / ``DEFAULT_PROVIDERS.APP_BASED`` select the generic api/app
      controllers; any other name is treated as a builtin provider and its controller is loaded
      from ``provider/builtin/<provider>/<provider>.py`` next to this file.

      :param provider: the name of the provider
      :param tool_id: the id of the tool
      :param tool_name: the name of the tool, defined in `get_tools`
      :param tool_parameters: the parameters of the tool
      :param credentials: the credentials of the tool
      :param prompt_messages: the prompt messages that the tool can use
      :return: the messages that the tool wants to send to the user
      :raises ToolProviderNotFoundError: if the builtin provider name is not a plain name, its
          module file does not exist, or the module does not define exactly one concrete
          provider controller class
      """
      # `import importlib` alone does not guarantee that the `util` submodule is loaded
      import importlib.util

      provider_entity: ToolProviderController = None
      if provider == DEFAULT_PROVIDERS.API_BASED:
          provider_entity = ApiBasedToolProviderController()
      elif provider == DEFAULT_PROVIDERS.APP_BASED:
          provider_entity = AppBasedToolProviderEntity()

      if provider_entity is None:
          # the name is spliced into a filesystem path below, so only accept a single plain
          # path component (rejects e.g. '', '..', 'a/b', '../x')
          if not provider or provider in ('.', '..') or path.basename(provider) != provider:
              raise ToolProviderNotFoundError(f'provider {provider} not found')

          # same location list_builtin_providers() scans: <this dir>/provider/builtin/<name>/<name>.py
          py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
          if not path.isfile(py_path):
              raise ToolProviderNotFoundError(f'provider {provider} not found')

          spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
          mod = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(mod)

          # the module namespace also contains the controller base classes it imported;
          # only a concrete subclass defined for this provider counts
          classes = [
              x for x in vars(mod).values()
              if isinstance(x, type)
              and issubclass(x, ToolProviderController)
              and x not in (ToolProviderController, BuiltinToolProviderController)
          ]
          if len(classes) == 0:
              raise ToolProviderNotFoundError(f'provider {provider} not found')
          if len(classes) > 1:
              raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
          provider_entity = classes[0]()

      return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
  @staticmethod
  def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
      """
      get the builtin provider

      :param provider: the name of the provider (its ``identity.name``)
      :return: the cached provider controller
      :raises ToolProviderNotFoundError: if no builtin provider with that name is loaded
      """
      # the module-level cache is only read/mutated here, never rebound, so no `global` is needed;
      # it is filled lazily on first use
      if not _builtin_providers:
          ToolManager.list_builtin_providers()

      if provider not in _builtin_providers:
          raise ToolProviderNotFoundError(f'builtin provider {provider} not found')

      return _builtin_providers[provider]
  @staticmethod
  def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
      """
      Look up one tool exposed by a builtin provider.

      :param provider: the name of the builtin provider
      :param tool_name: the name of the tool inside that provider
      :return: the tool returned by the provider controller's ``get_tool`` (the provider itself is not returned)
      """
      return ToolManager.get_builtin_provider(provider).get_tool(tool_name)
  @staticmethod
  def get_tool(provider_type: str, provider_id: str, tool_name: str, tanent_id: str = None) \
          -> Union[BuiltinTool, ApiTool]:
      """
      get the tool

      :param provider_type: the type of the provider: 'builtin', 'api' or 'app'
      :param provider_id: for 'builtin', the name of the builtin provider;
          for 'api', the id of the tenant's api provider
      :param tool_name: the name of the tool
      :param tanent_id: the tenant id, required for 'api' providers
          (parameter name kept as-is for keyword callers)
      :return: the tool
      :raises ValueError: if ``provider_type`` is 'api' and ``tanent_id`` is None
      :raises NotImplementedError: if ``provider_type`` is 'app'
      :raises ToolProviderNotFoundError: if ``provider_type`` is unknown
      """
      if provider_type == 'builtin':
          return ToolManager.get_builtin_tool(provider_id, tool_name)
      elif provider_type == 'api':
          if tanent_id is None:
              raise ValueError('tenant id is required for api provider')
          api_provider, _ = ToolManager.get_api_provider_controller(tanent_id, provider_id)
          return api_provider.get_tool(tool_name)
      elif provider_type == 'app':
          raise NotImplementedError('app provider not implemented')
      else:
          raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  @staticmethod
  def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id,
                       agent_callback: DifyAgentCallbackHandler = None) \
          -> Union[BuiltinTool, ApiTool]:
      """
      get the tool runtime: a fork of the tool bound to the tenant and its decrypted credentials

      :param provider_type: the type of the provider: 'builtin', 'api' or 'app'
      :param provider_name: for 'builtin', the builtin provider name; for 'api', the api provider id
          (it is passed to ``get_api_provider_controller`` as ``provider_id``)
      :param tool_name: the name of the tool
      :param tanent_id: the tenant id, stored in the runtime meta; required for 'api' providers
      :param agent_callback: forwarded to ``fork_tool_runtime`` for builtin tools
      :return: the tool runtime
      :raises ToolProviderNotFoundError: if a builtin provider that needs credentials has none
          stored for the tenant, or ``provider_type`` is unknown
      :raises ValueError: if ``provider_type`` is 'api' and ``tanent_id`` is None
      :raises NotImplementedError: if ``provider_type`` is 'app'
      """
      if provider_type == 'builtin':
          # resolve the controller once and reuse it for the tool lookup and credential decryption
          provider_controller = ToolManager.get_builtin_provider(provider_name)
          builtin_tool = provider_controller.get_tool(tool_name)

          # check if the builtin tool need credentials
          if not provider_controller.need_credentials:
              return builtin_tool.fork_tool_runtime(meta={
                  'tenant_id': tanent_id,
                  'credentials': {},
              }, agent_callback=agent_callback)

          # get the tenant's stored credentials
          builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
              BuiltinToolProvider.tenant_id == tanent_id,
              BuiltinToolProvider.provider == provider_name,
          ).first()

          if builtin_provider is None:
              raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')

          # decrypt the credentials
          tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=provider_controller)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(builtin_provider.credentials)

          return builtin_tool.fork_tool_runtime(meta={
              'tenant_id': tanent_id,
              'credentials': decrypted_credentials,
              'runtime_parameters': {}
          }, agent_callback=agent_callback)

      elif provider_type == 'api':
          if tanent_id is None:
              raise ValueError('tenant id is required for api provider')

          api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name)

          # decrypt the credentials
          tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)

          # NOTE(review): unlike the builtin branch, agent_callback is not forwarded here —
          # confirm whether the api tool's fork_tool_runtime accepts it before passing it through
          return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
              'tenant_id': tanent_id,
              'credentials': decrypted_credentials,
          })

      elif provider_type == 'app':
          raise NotImplementedError('app provider not implemented')
      else:
          raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  @staticmethod
  def get_builtin_provider_icon(provider: str) -> Tuple[str, str]:
      """
      Locate the icon file shipped with a builtin provider.

      The icon lives under ``provider/builtin/<provider>/_assets/`` next to this file, with the
      file name taken from the provider's ``identity.icon``.

      :param provider: the name of the builtin provider
      :return: (absolute path of the icon, its mime type); the mime type falls back to
          'application/octet-stream' when it cannot be guessed from the file name
      :raises ToolProviderNotFoundError: if the provider is unknown or the icon file does not exist
      """
      icon_file = ToolManager.get_builtin_provider(provider).identity.icon
      assets_dir = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets')
      icon_path = path.join(assets_dir, icon_file)

      if not path.exists(icon_path):
          raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')

      guessed_type, _encoding = mimetypes.guess_type(icon_path)
      return icon_path, (guessed_type if guessed_type else 'application/octet-stream')
  @staticmethod
  def list_builtin_providers() -> List[BuiltinToolProviderController]:
      """
      List every builtin tool provider controller, loading them on first use.

      Each sub-directory ``<name>`` of ``provider/builtin`` (next to this file) whose name does not
      start with ``__`` is expected to contain ``<name>.py`` defining exactly one subclass of
      ``BuiltinToolProviderController``; that class is instantiated without arguments. The
      instances are cached in ``_builtin_providers`` keyed by ``identity.name``, and later calls
      return the cached values.

      :return: the builtin provider controllers
      :raises ToolProviderNotFoundError: if a provider module defines no controller class or more than one
      """
      # `import importlib` alone does not guarantee that the `util` submodule is loaded
      import importlib.util

      # use cache first
      if _builtin_providers:
          return list(_builtin_providers.values())

      builtin_dir = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')
      builtin_providers = []
      for provider in listdir(builtin_dir):
          provider_dir = path.join(builtin_dir, provider)
          # skip dunder entries (e.g. __pycache__, __init__.py) and plain files
          if provider.startswith('__') or not path.isdir(provider_dir):
              continue

          py_path = path.join(provider_dir, f'{provider}.py')
          spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
          mod = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(mod)

          # load all classes; the imported base class itself is not a provider
          classes = [
              obj for obj in vars(mod).values()
              if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
          ]
          if len(classes) == 0:
              raise ToolProviderNotFoundError(f'provider {provider} not found')
          if len(classes) > 1:
              raise ToolProviderNotFoundError(f'multiple providers found for {provider}')

          # init provider
          builtin_providers.append(classes[0]())

      # cache the builtin providers (the dict is mutated in place, so no `global` is needed)
      for provider_controller in builtin_providers:
          _builtin_providers[provider_controller.identity.name] = provider_controller

      return builtin_providers
  @staticmethod
  def user_list_providers(
      user_id: str,
      tenant_id: str,
  ) -> List[UserToolProvider]:
      """
      List the tool providers of a tenant together with their masked team credentials.

      Builtin providers come from the code base. Each one whose tenant has a stored builtin
      provider row (or that needs no credentials) is marked as team-authorized. The tenant's api
      providers are then added, keyed by their name.

      :param user_id: the requesting user (not used by this method)
      :param tenant_id: the tenant whose providers are listed
      :return: the providers, ordered by ``BuiltinToolProviderSort.sort``
      """
      result_providers: Dict[str, UserToolProvider] = {}

      # builtin providers known to the code base
      for provider in ToolManager.list_builtin_providers():
          user_provider = UserToolProvider(
              id=provider.identity.name,
              author=provider.identity.author,
              name=provider.identity.name,
              description=I18nObject(
                  en_US=provider.identity.description.en_US,
                  zh_Hans=provider.identity.description.zh_Hans,
              ),
              icon=provider.identity.icon,
              label=I18nObject(
                  en_US=provider.identity.label.en_US,
                  zh_Hans=provider.identity.label.zh_Hans,
              ),
              type=UserToolProvider.ProviderType.BUILTIN,
              team_credentials={},
              is_team_authorization=False,
          )
          result_providers[provider.identity.name] = user_provider

          # pre-fill the team credentials from the credentials schema
          for name, value in provider.get_credentials_schema().items():
              user_provider.team_credentials[name] = \
                  ToolProviderCredentials.CredentialsType.defaut(value.type)

          # providers without credentials are always usable and cannot be deleted
          if not provider.need_credentials:
              user_provider.is_team_authorization = True
              user_provider.allow_delete = False

      # builtin providers the tenant has configured
      db_builtin_providers: List[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
          filter(BuiltinToolProvider.tenant_id == tenant_id).all()

      for db_builtin_provider in db_builtin_providers:
          provider_name = db_builtin_provider.provider
          if provider_name not in result_providers:
              # the row references a builtin provider that is no longer shipped; skip it rather
              # than failing the whole listing with a KeyError
              logger.warning(f'builtin provider {provider_name} of tenant {tenant_id} not found, skipped')
              continue

          result_providers[provider_name].is_team_authorization = True

          # package builtin tool provider controller
          controller = ToolManager.get_builtin_provider(provider_name)
          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
          # decrypt the credentials and mask the credentials
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=db_builtin_provider.credentials)
          masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
          result_providers[provider_name].team_credentials = masked_credentials

      # api providers of the tenant
      db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
          filter(ApiToolProvider.tenant_id == tenant_id).all()

      for db_api_provider in db_api_providers:
          username = 'Anonymous'
          try:
              username = db_api_provider.user.name
          except Exception as e:
              logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')

          credentials = db_api_provider.credentials
          provider_name = db_api_provider.name
          # NOTE(review): keyed by name, so an api provider named like a builtin one replaces it
          result_providers[provider_name] = UserToolProvider(
              id=db_api_provider.id,
              author=username,
              name=db_api_provider.name,
              description=I18nObject(
                  en_US=db_api_provider.description,
                  zh_Hans=db_api_provider.description,
              ),
              icon=db_api_provider.icon,
              label=I18nObject(
                  en_US=db_api_provider.name,
                  zh_Hans=db_api_provider.name,
              ),
              type=UserToolProvider.ProviderType.API,
              team_credentials={},
              is_team_authorization=True,
          )

          # a missing 'auth_type' means no authentication instead of a KeyError
          auth_type = ApiProviderAuthType.API_KEY if credentials.get('auth_type') == 'api_key' \
              else ApiProviderAuthType.NONE
          controller = ApiBasedToolProviderController.from_db(
              db_provider=db_api_provider,
              auth_type=auth_type,
          )
          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
          # decrypt the credentials and mask the credentials
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
          masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
          result_providers[provider_name].team_credentials = masked_credentials

      return BuiltinToolProviderSort.sort(list(result_providers.values()))
  @staticmethod
  def get_api_provider_controller(tanent_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
      """
      get the api provider controller of a tenant, with its bundled tools loaded

      :param tanent_id: the tenant that owns the provider
      :param provider_id: the id of the api provider
      :return: the provider controller, the stored (still encrypted) credentials
      :raises ToolProviderNotFoundError: if the tenant has no api provider with that id
      """
      provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
          ApiToolProvider.id == provider_id,
          ApiToolProvider.tenant_id == tanent_id,
      ).first()

      if provider is None:
          raise ToolProviderNotFoundError(f'api provider {provider_id} not found')

      credentials = provider.credentials
      # a missing 'auth_type' means no authentication instead of a KeyError
      auth_type = ApiProviderAuthType.API_KEY if credentials.get('auth_type') == 'api_key' \
          else ApiProviderAuthType.NONE

      controller = ApiBasedToolProviderController.from_db(provider, auth_type)
      controller.load_bundled_tools(provider.tools)

      return controller, credentials
  @staticmethod
  def user_get_api_provider(provider: str, tenant_id: str) -> dict:
      """
      get api provider

      Load the tenant's api tool provider named ``provider`` and describe it as a JSON-compatible
      dict, with its credentials decrypted and then masked.

      :param provider: the name of the api provider
      :param tenant_id: the tenant that owns the provider
      :return: dict with keys schema_type, schema, tools, icon, description, credentials, privacy_policy
      :raises ValueError: if the tenant has no api provider with that name
      """
      # keep the requested name separate from the db row so the error message can report it
      db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
          ApiToolProvider.tenant_id == tenant_id,
          ApiToolProvider.name == provider,
      ).first()

      if db_provider is None:
          raise ValueError(f'you have not added provider {provider}')

      # missing (None -> TypeError) or malformed (ValueError) stored credentials fall back to {}
      try:
          credentials = json.loads(db_provider.credentials_str) or {}
      except (TypeError, ValueError):
          credentials = {}
      if not isinstance(credentials, dict):
          credentials = {}

      # package tool provider controller; a missing 'auth_type' means no authentication
      auth_type = ApiProviderAuthType.API_KEY if credentials.get('auth_type') == 'api_key' \
          else ApiProviderAuthType.NONE
      controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)

      # init tool configuration, then decrypt and mask the credentials
      tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
      decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
      masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

      try:
          icon = json.loads(db_provider.icon)
      except (TypeError, ValueError):
          # default icon; the emoji is written as a single code point (U+1F601) rather than a
          # surrogate pair, which in a Python str would be two lone surrogates
          icon = {
              "background": "#252525",
              "content": "\U0001F601"
          }

      return json.loads(serialize_base_model_dict({
          'schema_type': db_provider.schema_type,
          'schema': db_provider.schema,
          'tools': db_provider.tools,
          'icon': icon,
          'description': db_provider.description,
          'credentials': masked_credentials,
          'privacy_policy': db_provider.privacy_policy
      }))