tool_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. import importlib
  2. import json
  3. import logging
  4. import mimetypes
  5. from os import listdir, path
  6. from typing import Any, Union
  7. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  8. from core.model_runtime.entities.message_entities import PromptMessage
  9. from core.provider_manager import ProviderManager
  10. from core.tools.entities.common_entities import I18nObject
  11. from core.tools.entities.constant import DEFAULT_PROVIDERS
  12. from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
  13. from core.tools.entities.user_entities import UserToolProvider
  14. from core.tools.errors import ToolProviderNotFoundError
  15. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  16. from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
  17. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  18. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  19. from core.tools.provider.model_tool_provider import ModelToolProviderController
  20. from core.tools.provider.tool_provider import ToolProviderController
  21. from core.tools.tool.api_tool import ApiTool
  22. from core.tools.tool.builtin_tool import BuiltinTool
  23. from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
  24. from core.tools.utils.encoder import serialize_base_model_dict
  25. from extensions.ext_database import db
  26. from models.tools import ApiToolProvider, BuiltinToolProvider
  # Module-level logger, named after this module's import path.
  logger: logging.Logger = logging.getLogger(__name__)

  # Process-wide caches, filled lazily by ToolManager.list_builtin_providers():
  #   _builtin_providers:    builtin provider identity name -> provider controller instance
  #   _builtin_tools_labels: builtin tool identity name -> that tool's label
  _builtin_providers: dict = {}
  _builtin_tools_labels: dict = {}
  30. class ToolManager:
  @staticmethod
  def invoke(
      provider: str,
      tool_id: str,
      tool_name: str,
      tool_parameters: dict[str, Any],
      credentials: dict[str, Any],
      prompt_messages: list[PromptMessage],
  ) -> list[ToolInvokeMessage]:
      """
      Invoke a tool through its provider.

      The API-based and app-based providers are instantiated directly; any other
      provider name is resolved to the builtin provider module
      ``provider/builtin/<provider>/<provider>.py`` located next to this file.

      :param provider: the name of the provider
      :param tool_id: the id of the tool
      :param tool_name: the name of the tool, defined in `get_tools`
      :param tool_parameters: the parameters of the tool
      :param credentials: the credentials of the tool
      :param prompt_messages: the prompt messages that the tool can use
      :return: the messages that the tool wants to send to the user
      :raises ToolProviderNotFoundError: if the provider name is not a plain directory
          name, its module file does not exist, or the module does not define exactly
          one ToolProviderController subclass
      """
      # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded
      import importlib.util

      provider_entity: ToolProviderController = None
      if provider == DEFAULT_PROVIDERS.API_BASED:
          provider_entity = ApiBasedToolProviderController()
      elif provider == DEFAULT_PROVIDERS.APP_BASED:
          provider_entity = AppBasedToolProviderEntity()

      if provider_entity is None:
          # `provider` becomes part of the path of a module that is executed below,
          # so only accept a bare directory name (no path separators, no '.' / '..').
          if not provider or provider in ('.', '..') or path.basename(provider) != provider:
              raise ToolProviderNotFoundError(f'provider {provider} not found')

          # Builtin providers live in <this dir>/provider/builtin/: the directory that
          # list_builtin_providers() scans and the package the module name below names.
          py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
          if not path.isfile(py_path):
              raise ToolProviderNotFoundError(f'provider {provider} not found')

          module_name = f'core.tools.provider.builtin.{provider}.{provider}'
          spec = importlib.util.spec_from_file_location(module_name, py_path)
          mod = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(mod)

          # Only count classes defined by the module itself. Controller base classes the
          # module merely imports (presumably e.g. BuiltinToolProviderController, if it
          # derives from ToolProviderController) would otherwise be counted as extra
          # candidates and trigger the "multiple providers" error.
          classes = [
              x for x in vars(mod).values()
              if isinstance(x, type)
              and x is not ToolProviderController
              and issubclass(x, ToolProviderController)
              and x.__module__ == mod.__name__
          ]
          if len(classes) == 0:
              raise ToolProviderNotFoundError(f'provider {provider} not found')
          if len(classes) > 1:
              raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
          provider_entity = classes[0]()

      return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
  @staticmethod
  def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
      """
      Get the builtin provider controller with the given name.

      All builtin providers are loaded and cached on first use.

      :param provider: the name of the provider
      :return: the provider controller
      :raises ToolProviderNotFoundError: if no builtin provider has this name
      """
      # the cache dict is only read/mutated, never rebound, so no `global` is needed
      if not _builtin_providers:
          ToolManager.list_builtin_providers()

      try:
          return _builtin_providers[provider]
      except KeyError:
          raise ToolProviderNotFoundError(f'builtin provider {provider} not found') from None
  @staticmethod
  def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
      """
      Look up one tool of a builtin provider.

      :param provider: the name of the builtin provider
      :param tool_name: the name of the tool within that provider
      :return: the tool, as returned by the provider controller's ``get_tool``
      :raises ToolProviderNotFoundError: if the builtin provider does not exist
      """
      return ToolManager.get_builtin_provider(provider).get_tool(tool_name)
  @staticmethod
  def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
          -> Union[BuiltinTool, ApiTool]:
      """
      Look up a tool without forking a tenant runtime for it (see get_tool_runtime).

      :param provider_type: the type of the provider: 'builtin', 'api' or 'app'
      :param provider_id: the name of a builtin provider, or the id of an api provider
      :param tool_name: the name of the tool
      :param tenant_id: the id of the tenant; required for 'api' providers
      :return: the tool
      :raises ValueError: if an 'api' provider is requested without a tenant id
      :raises NotImplementedError: for 'app' providers
      :raises ToolProviderNotFoundError: for an unknown provider type or provider
      """
      if provider_type == 'builtin':
          return ToolManager.get_builtin_tool(provider_id, tool_name)

      if provider_type == 'api':
          if tenant_id is None:
              raise ValueError('tenant id is required for api provider')
          controller, _credentials = ToolManager.get_api_provider_controller(tenant_id, provider_id)
          return controller.get_tool(tool_name)

      if provider_type == 'app':
          raise NotImplementedError('app provider not implemented')

      raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  @staticmethod
  def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
                       agent_callback: DifyAgentCallbackHandler = None) \
          -> Union[BuiltinTool, ApiTool]:
      """
      Get a tool forked with a runtime for a tenant (tenant id plus decrypted credentials).

      :param provider_type: the type of the provider: 'builtin', 'api', 'model' or 'app'
      :param provider_name: the name of a builtin/model provider, or the id of an api provider
      :param tool_name: the name of the tool
      :param tenant_id: the id of the tenant the tool runs for
      :param agent_callback: callback handler passed to builtin tools' fork_tool_runtime
      :return: the tool bound to its runtime
      :raises ValueError: if tenant_id is None for 'api' or 'model' providers
      :raises NotImplementedError: for 'app' providers
      :raises ToolProviderNotFoundError: for an unknown provider type or provider, or a
          builtin provider that needs credentials the tenant has not stored
      """
      if provider_type == 'builtin':
          # look the controller up once; it serves the tool lookup, the credential
          # check and the credential decryption
          provider_controller = ToolManager.get_builtin_provider(provider_name)
          builtin_tool = provider_controller.get_tool(tool_name)

          if not provider_controller.need_credentials:
              return builtin_tool.fork_tool_runtime(meta={
                  'tenant_id': tenant_id,
                  'credentials': {},
              }, agent_callback=agent_callback)

          # credentials the tenant stored for this builtin provider
          builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
              BuiltinToolProvider.tenant_id == tenant_id,
              BuiltinToolProvider.provider == provider_name,
          ).first()
          if builtin_provider is None:
              raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')

          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(builtin_provider.credentials)
          return builtin_tool.fork_tool_runtime(meta={
              'tenant_id': tenant_id,
              'credentials': decrypted_credentials,
              'runtime_parameters': {}
          }, agent_callback=agent_callback)

      elif provider_type == 'api':
          if tenant_id is None:
              raise ValueError('tenant id is required for api provider')
          api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
          # the stored credentials are decrypted before being handed to the tool
          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
          # NOTE(review): agent_callback is not forwarded to api (or model) tools; confirm
          # whether their fork_tool_runtime() accepts it before passing it through.
          return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
              'tenant_id': tenant_id,
              'credentials': decrypted_credentials,
          })

      elif provider_type == 'model':
          if tenant_id is None:
              raise ValueError('tenant id is required for model provider')
          model_provider = ToolManager.get_model_provider(tenant_id, provider_name)
          model_tool = model_provider.get_tool(tool_name)
          # model tools reuse the credentials of their configured model instance
          return model_tool.fork_tool_runtime(meta={
              'tenant_id': tenant_id,
              'credentials': model_tool.model_configuration['model_instance'].credentials
          })

      elif provider_type == 'app':
          raise NotImplementedError('app provider not implemented')

      else:
          raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  @staticmethod
  def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
      """
      Locate the icon file of a builtin provider.

      The provider name is resolved through get_builtin_provider first, so only names
      of loaded builtin providers reach the file-system path below.

      :param provider: the name of the builtin provider
      :return: (absolute path of the icon file, its MIME type); the MIME type falls back
               to 'application/octet-stream' when it cannot be guessed from the path
      :raises ToolProviderNotFoundError: if the provider is unknown or its icon file is missing
      """
      controller = ToolManager.get_builtin_provider(provider)
      icon_path = path.join(
          path.dirname(path.realpath(__file__)),
          'provider', 'builtin', provider, '_assets',
          controller.identity.icon,
      )

      if not path.exists(icon_path):
          raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')

      guessed_type = mimetypes.guess_type(icon_path)[0]
      return icon_path, guessed_type or 'application/octet-stream'
  @staticmethod
  def list_builtin_providers() -> list[BuiltinToolProviderController]:
      """
      Load every builtin tool provider and cache it.

      Each sub-directory ``<name>`` of ``provider/builtin/`` (next to this file) whose name
      does not start with '__' is expected to contain ``<name>.py`` defining exactly one
      subclass of BuiltinToolProviderController, which is instantiated. The first call
      fills the module caches of providers (by identity name) and tool labels (by tool
      identity name); later calls return the cached providers.

      :return: the builtin provider controllers
      :raises ToolProviderNotFoundError: if a provider module defines no, or several,
          BuiltinToolProviderController subclasses
      """
      # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded
      import importlib.util

      # serve from the cache once it has been populated (the dict is only mutated,
      # never rebound, so no `global` statement is needed)
      if _builtin_providers:
          return list(_builtin_providers.values())

      builtin_dir = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')
      builtin_providers: list[BuiltinToolProviderController] = []

      for provider_dir_name in listdir(builtin_dir):
          # skip __pycache__, __init__.py and similar private entries
          if provider_dir_name.startswith('__'):
              continue
          provider_dir = path.join(builtin_dir, provider_dir_name)
          if not path.isdir(provider_dir):
              continue

          py_path = path.join(provider_dir, f'{provider_dir_name}.py')
          spec = importlib.util.spec_from_file_location(
              f'core.tools.provider.builtin.{provider_dir_name}.{provider_dir_name}', py_path
          )
          mod = importlib.util.module_from_spec(spec)
          spec.loader.exec_module(mod)

          # every concrete BuiltinToolProviderController subclass visible in the module
          classes = [
              obj for obj in vars(mod).values()
              if isinstance(obj, type)
              and obj is not BuiltinToolProviderController
              and issubclass(obj, BuiltinToolProviderController)
          ]
          if len(classes) == 0:
              raise ToolProviderNotFoundError(f'provider {provider_dir_name} not found')
          if len(classes) > 1:
              raise ToolProviderNotFoundError(f'multiple providers found for {provider_dir_name}')

          builtin_providers.append(classes[0]())

      # cache the providers and the labels of all their tools
      for provider_controller in builtin_providers:
          _builtin_providers[provider_controller.identity.name] = provider_controller
          for tool in provider_controller.get_tools():
              _builtin_tools_labels[tool.identity.name] = tool.identity.label

      return builtin_providers
  @staticmethod
  def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]:
      """
      List the model-based tool providers available to a tenant.

      A provider is included only if it has a model tool configuration and its
      configuration passes ModelToolProviderController.is_configuration_valid().

      :param tenant_id: the id of the tenant; when falsy, the placeholder id
          'ffffffff-ffff-ffff-ffff-ffffffffffff' is used (presumably a sentinel for
          "no particular tenant" -- verify against ProviderManager)
      :return: the model tool provider controllers
      """
      effective_tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'

      tool_configurations = ModelToolConfigurationManager.get_all_configuration()
      tenant_configurations = ProviderManager().get_configurations(effective_tenant_id).values()

      return [
          ModelToolProviderController.from_db(configuration)
          for configuration in tenant_configurations
          if configuration.provider.provider in tool_configurations
          and ModelToolProviderController.is_configuration_valid(configuration)
      ]
  @staticmethod
  def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController:
      """
      Get the model-based tool provider with the given name for a tenant.

      :param tenant_id: the id of the tenant
      :param provider_name: the name of the model provider
      :return: the provider controller built from the tenant's provider configuration
      :raises ToolProviderNotFoundError: if the tenant has no configuration for this provider
      """
      configuration = ProviderManager().get_configurations(tenant_id).get(provider_name)
      if configuration is None:
          raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
      return ModelToolProviderController.from_db(configuration)
  @staticmethod
  def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
      """
      Get the label of a builtin tool.

      :param tool_name: the identity name of the tool
      :return: the tool's label, or None if no loaded builtin tool has this name
      """
      if not _builtin_tools_labels:
          # labels are collected while the builtin providers are loaded
          ToolManager.list_builtin_providers()
      return _builtin_tools_labels.get(tool_name)
  @staticmethod
  def user_list_providers(
      user_id: str,
      tenant_id: str,
  ) -> list[UserToolProvider]:
      """
      List every tool provider visible to a tenant: builtin, model-based and API-based.

      Builtin providers are always listed; they are marked as authorized when they
      need no credentials or when the tenant has stored credentials for them (which
      are returned decrypted and then masked). Model providers are listed when
      list_model_providers() returns them. API providers stored by the tenant are
      listed with masked credentials.

      :param user_id: the id of the requesting user (not used by this implementation)
      :param tenant_id: the id of the tenant
      :return: the providers, ordered by BuiltinToolProviderSort.sort
      """
      result_providers: dict[str, UserToolProvider] = {}

      # builtin providers, keyed by provider name
      for provider in ToolManager.list_builtin_providers():
          user_provider = UserToolProvider(
              id=provider.identity.name,
              author=provider.identity.author,
              name=provider.identity.name,
              description=I18nObject(
                  en_US=provider.identity.description.en_US,
                  zh_Hans=provider.identity.description.zh_Hans,
              ),
              icon=provider.identity.icon,
              label=I18nObject(
                  en_US=provider.identity.label.en_US,
                  zh_Hans=provider.identity.label.zh_Hans,
              ),
              type=UserToolProvider.ProviderType.BUILTIN,
              team_credentials={},
              is_team_authorization=False,
          )
          result_providers[provider.identity.name] = user_provider

          # initialise each credential field from CredentialsType.default(<field type>)
          for name, value in provider.get_credentials_schema().items():
              user_provider.team_credentials[name] = \
                  ToolProviderCredentials.CredentialsType.default(value.type)

          # providers without credentials are always authorized and not deletable
          if not provider.need_credentials:
              user_provider.is_team_authorization = True
              user_provider.allow_delete = False

      # builtin providers the tenant has stored credentials for
      db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
          filter(BuiltinToolProvider.tenant_id == tenant_id).all()

      for db_builtin_provider in db_builtin_providers:
          provider_name = db_builtin_provider.provider
          user_provider = result_providers.get(provider_name)
          if user_provider is None:
              # a stored row for a builtin provider that is not loaded (e.g. it was
              # removed); skip it instead of failing the whole listing with a KeyError
              logger.warning(f'builtin provider {provider_name} stored for tenant {tenant_id} is not loaded, skipped')
              continue

          user_provider.is_team_authorization = True

          # decrypt the stored credentials, then mask them for display
          controller = ToolManager.get_builtin_provider(provider_name)
          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=db_builtin_provider.credentials)
          user_provider.team_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)

      # model tool providers, keyed with a prefix so they cannot clash with builtin names
      for provider in ToolManager.list_model_providers(tenant_id=tenant_id):
          result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider(
              id=provider.identity.name,
              author=provider.identity.author,
              name=provider.identity.name,
              description=I18nObject(
                  en_US=provider.identity.description.en_US,
                  zh_Hans=provider.identity.description.zh_Hans,
              ),
              icon=provider.identity.icon,
              label=I18nObject(
                  en_US=provider.identity.label.en_US,
                  zh_Hans=provider.identity.label.zh_Hans,
              ),
              type=UserToolProvider.ProviderType.MODEL,
              team_credentials={},
              is_team_authorization=provider.is_active,
          )

      # api providers stored by the tenant
      db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
          filter(ApiToolProvider.tenant_id == tenant_id).all()

      for db_api_provider in db_api_providers:
          # best effort: a missing/broken user relation must not break the listing
          username = 'Anonymous'
          try:
              username = db_api_provider.user.name
          except Exception as e:
              logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')

          credentials = db_api_provider.credentials
          user_provider = UserToolProvider(
              id=db_api_provider.id,
              author=username,
              name=db_api_provider.name,
              description=I18nObject(
                  en_US=db_api_provider.description,
                  zh_Hans=db_api_provider.description,
              ),
              icon=db_api_provider.icon,
              label=I18nObject(
                  en_US=db_api_provider.name,
                  zh_Hans=db_api_provider.name,
              ),
              type=UserToolProvider.ProviderType.API,
              team_credentials={},
              is_team_authorization=True,
          )
          # prefixed key, like model providers: keyed by bare name, an api provider named
          # like a builtin provider would silently replace that builtin provider's entry
          result_providers[f'api_provider.{db_api_provider.name}'] = user_provider

          controller = ApiBasedToolProviderController.from_db(
              db_provider=db_api_provider,
              auth_type=ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
          )
          # decrypt the stored credentials, then mask them for display
          tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
          decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
          user_provider.team_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)

      return BuiltinToolProviderSort.sort(list(result_providers.values()))
  @staticmethod
  def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ApiBasedToolProviderController, dict[str, Any]]:
      """
      Build the controller of an API-based tool provider owned by a tenant.

      :param tenant_id: the id of the tenant that owns the provider
      :param provider_id: the id of the api provider
      :return: the provider controller (with the provider's bundled tools loaded) and the
               provider's stored credentials as-is (they are not decrypted here)
      :raises ToolProviderNotFoundError: if the tenant has no api provider with this id
      """
      db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
          ApiToolProvider.id == provider_id,
          ApiToolProvider.tenant_id == tenant_id,
      ).first()
      if db_provider is None:
          raise ToolProviderNotFoundError(f'api provider {provider_id} not found')

      uses_api_key = db_provider.credentials['auth_type'] == 'api_key'
      auth_type = ApiProviderAuthType.API_KEY if uses_api_key else ApiProviderAuthType.NONE

      controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
      controller.load_bundled_tools(db_provider.tools)
      return controller, db_provider.credentials
  @staticmethod
  def user_get_api_provider(provider: str, tenant_id: str) -> dict:
      """
      Get an API-based tool provider of a tenant, packaged for display to the user.

      :param provider: the name of the api provider
      :param tenant_id: the id of the tenant that owns the provider
      :return: a JSON-compatible dict with the schema type, schema, tools, icon,
               description, masked credentials and privacy policy of the provider
      :raises ValueError: if the tenant has no api provider with this name
      """
      # a separate name for the row keeps `provider` (the requested name) intact for
      # the error message below
      db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
          ApiToolProvider.tenant_id == tenant_id,
          ApiToolProvider.name == provider,
      ).first()

      if db_provider is None:
          raise ValueError(f'you have not added provider {provider}')

      # missing (None) or malformed stored credentials fall back to "no credentials"
      try:
          credentials = json.loads(db_provider.credentials_str) or {}
      except (TypeError, ValueError):
          credentials = {}
      if not isinstance(credentials, dict):
          credentials = {}

      # .get(): with the empty fallback above, indexing would always raise KeyError
      controller = ApiBasedToolProviderController.from_db(
          db_provider,
          ApiProviderAuthType.API_KEY if credentials.get('auth_type') == 'api_key' else ApiProviderAuthType.NONE
      )

      # decrypt the stored credentials, then mask them for display
      tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
      decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
      masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

      try:
          icon = json.loads(db_provider.icon)
      except (TypeError, ValueError):
          # U+1F601 as a real code point; the JSON-style escape "\ud83d\ude01" is a
          # surrogate pair, which in a Python str is two lone surrogates, not the emoji
          icon = {
              "background": "#252525",
              "content": "\U0001F601"
          }

      return json.loads(serialize_base_model_dict({
          'schema_type': db_provider.schema_type,
          'schema': db_provider.schema,
          'tools': db_provider.tools,
          'icon': icon,
          'description': db_provider.description,
          'credentials': masked_credentials,
          'privacy_policy': db_provider.privacy_policy
      }))