tool_engine.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. from copy import deepcopy
  2. from datetime import datetime, timezone
  3. from typing import Union
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  6. from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
  7. from core.file.file_obj import FileTransferMethod
  8. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
  9. from core.tools.errors import (
  10. ToolEngineInvokeError,
  11. ToolInvokeError,
  12. ToolNotFoundError,
  13. ToolNotSupportedError,
  14. ToolParameterValidationError,
  15. ToolProviderCredentialValidationError,
  16. ToolProviderNotFoundError,
  17. )
  18. from core.tools.tool.tool import Tool
  19. from core.tools.utils.message_transformer import ToolFileMessageTransformer
  20. from extensions.ext_database import db
  21. from models.model import Message, MessageFile
  22. class ToolEngine:
  23. """
  24. Tool runtime engine take care of the tool executions.
  25. """
  26. @staticmethod
  27. def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
  28. user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
  29. agent_tool_callback: DifyAgentCallbackHandler) \
  30. -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
  31. """
  32. Agent invokes the tool with the given arguments.
  33. """
  34. # check if arguments is a string
  35. if isinstance(tool_parameters, str):
  36. # check if this tool has only one parameter
  37. parameters = [
  38. parameter for parameter in tool.parameters
  39. if parameter.form == ToolParameter.ToolParameterForm.LLM
  40. ]
  41. if parameters and len(parameters) == 1:
  42. tool_parameters = {
  43. parameters[0].name: tool_parameters
  44. }
  45. else:
  46. raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
  47. # invoke the tool
  48. try:
  49. # hit the callback handler
  50. agent_tool_callback.on_tool_start(
  51. tool_name=tool.identity.name,
  52. tool_inputs=tool_parameters
  53. )
  54. meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
  55. response = ToolFileMessageTransformer.transform_tool_invoke_messages(
  56. messages=response,
  57. user_id=user_id,
  58. tenant_id=tenant_id,
  59. conversation_id=message.conversation_id
  60. )
  61. # extract binary data from tool invoke message
  62. binary_files = ToolEngine._extract_tool_response_binary(response)
  63. # create message file
  64. message_files = ToolEngine._create_message_files(
  65. tool_messages=binary_files,
  66. agent_message=message,
  67. invoke_from=invoke_from,
  68. user_id=user_id
  69. )
  70. plain_text = ToolEngine._convert_tool_response_to_str(response)
  71. # hit the callback handler
  72. agent_tool_callback.on_tool_end(
  73. tool_name=tool.identity.name,
  74. tool_inputs=tool_parameters,
  75. tool_outputs=plain_text
  76. )
  77. # transform tool invoke message to get LLM friendly message
  78. return plain_text, message_files, meta
  79. except ToolProviderCredentialValidationError as e:
  80. error_response = "Please check your tool provider credentials"
  81. agent_tool_callback.on_tool_error(e)
  82. except (
  83. ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
  84. ) as e:
  85. error_response = f"there is not a tool named {tool.identity.name}"
  86. agent_tool_callback.on_tool_error(e)
  87. except (
  88. ToolParameterValidationError
  89. ) as e:
  90. error_response = f"tool parameters validation error: {e}, please check your tool parameters"
  91. agent_tool_callback.on_tool_error(e)
  92. except ToolInvokeError as e:
  93. error_response = f"tool invoke error: {e}"
  94. agent_tool_callback.on_tool_error(e)
  95. except ToolEngineInvokeError as e:
  96. meta = e.args[0]
  97. error_response = f"tool invoke error: {meta.error}"
  98. agent_tool_callback.on_tool_error(e)
  99. return error_response, [], meta
  100. except Exception as e:
  101. error_response = f"unknown error: {e}"
  102. agent_tool_callback.on_tool_error(e)
  103. return error_response, [], ToolInvokeMeta.error_instance(error_response)
  104. @staticmethod
  105. def workflow_invoke(tool: Tool, tool_parameters: dict,
  106. user_id: str, workflow_id: str,
  107. workflow_tool_callback: DifyWorkflowCallbackHandler) \
  108. -> list[ToolInvokeMessage]:
  109. """
  110. Workflow invokes the tool with the given arguments.
  111. """
  112. try:
  113. # hit the callback handler
  114. workflow_tool_callback.on_tool_start(
  115. tool_name=tool.identity.name,
  116. tool_inputs=tool_parameters
  117. )
  118. response = tool.invoke(user_id, tool_parameters)
  119. # hit the callback handler
  120. workflow_tool_callback.on_tool_end(
  121. tool_name=tool.identity.name,
  122. tool_inputs=tool_parameters,
  123. tool_outputs=response
  124. )
  125. return response
  126. except Exception as e:
  127. workflow_tool_callback.on_tool_error(e)
  128. raise e
  129. @staticmethod
  130. def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
  131. -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
  132. """
  133. Invoke the tool with the given arguments.
  134. """
  135. started_at = datetime.now(timezone.utc)
  136. meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
  137. 'tool_name': tool.identity.name,
  138. 'tool_provider': tool.identity.provider,
  139. 'tool_provider_type': tool.tool_provider_type().value,
  140. 'tool_parameters': deepcopy(tool.runtime.runtime_parameters),
  141. 'tool_icon': tool.identity.icon
  142. })
  143. try:
  144. response = tool.invoke(user_id, tool_parameters)
  145. except Exception as e:
  146. meta.error = str(e)
  147. raise ToolEngineInvokeError(meta)
  148. finally:
  149. ended_at = datetime.now(timezone.utc)
  150. meta.time_cost = (ended_at - started_at).total_seconds()
  151. return meta, response
  152. @staticmethod
  153. def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
  154. """
  155. Handle tool response
  156. """
  157. result = ''
  158. for response in tool_response:
  159. if response.type == ToolInvokeMessage.MessageType.TEXT:
  160. result += response.message
  161. elif response.type == ToolInvokeMessage.MessageType.LINK:
  162. result += f"result link: {response.message}. please tell user to check it."
  163. elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  164. response.type == ToolInvokeMessage.MessageType.IMAGE:
  165. result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
  166. else:
  167. result += f"tool response: {response.message}."
  168. return result
  169. @staticmethod
  170. def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
  171. """
  172. Extract tool response binary
  173. """
  174. result = []
  175. for response in tool_response:
  176. if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  177. response.type == ToolInvokeMessage.MessageType.IMAGE:
  178. result.append(ToolInvokeMessageBinary(
  179. mimetype=response.meta.get('mime_type', 'octet/stream'),
  180. url=response.message,
  181. save_as=response.save_as,
  182. ))
  183. elif response.type == ToolInvokeMessage.MessageType.BLOB:
  184. result.append(ToolInvokeMessageBinary(
  185. mimetype=response.meta.get('mime_type', 'octet/stream'),
  186. url=response.message,
  187. save_as=response.save_as,
  188. ))
  189. elif response.type == ToolInvokeMessage.MessageType.LINK:
  190. # check if there is a mime type in meta
  191. if response.meta and 'mime_type' in response.meta:
  192. result.append(ToolInvokeMessageBinary(
  193. mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
  194. url=response.message,
  195. save_as=response.save_as,
  196. ))
  197. return result
  198. @staticmethod
  199. def _create_message_files(
  200. tool_messages: list[ToolInvokeMessageBinary],
  201. agent_message: Message,
  202. invoke_from: InvokeFrom,
  203. user_id: str
  204. ) -> list[tuple[MessageFile, bool]]:
  205. """
  206. Create message file
  207. :param messages: messages
  208. :return: message files, should save as variable
  209. """
  210. result = []
  211. for message in tool_messages:
  212. file_type = 'bin'
  213. if 'image' in message.mimetype:
  214. file_type = 'image'
  215. elif 'video' in message.mimetype:
  216. file_type = 'video'
  217. elif 'audio' in message.mimetype:
  218. file_type = 'audio'
  219. elif 'text' in message.mimetype:
  220. file_type = 'text'
  221. elif 'pdf' in message.mimetype:
  222. file_type = 'pdf'
  223. elif 'zip' in message.mimetype:
  224. file_type = 'archive'
  225. # ...
  226. message_file = MessageFile(
  227. message_id=agent_message.id,
  228. type=file_type,
  229. transfer_method=FileTransferMethod.TOOL_FILE.value,
  230. belongs_to='assistant',
  231. url=message.url,
  232. upload_file_id=None,
  233. created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
  234. created_by=user_id,
  235. )
  236. db.session.add(message_file)
  237. db.session.commit()
  238. db.session.refresh(message_file)
  239. result.append((
  240. message_file,
  241. message.save_as
  242. ))
  243. db.session.close()
  244. return result