assistant_base_runner.py

import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import Optional, Union, cast

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import (
    AgentEntity,
    AgentToolEntity,
    ApplicationGenerateEntity,
    AppOrchestrationConfigEntity,
    InvokeFrom,
    ModelConfigEntity,
)
from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (
    ToolInvokeMessage,
    ToolInvokeMessageBinary,
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAssistantApplicationRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: ApplicationGenerateEntity,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[list[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: ModelInstance = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: application generate entity
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: memory
        :param prompt_messages: prompt messages
        :param variables_pool: tool runtime variables pool
        :param db_variables: db conversation variables
        :param model_instance: model instance
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.application_generate_entity.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
            retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
            return_resource=app_orchestration_config.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

    def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
        """
        Repack app orchestration config
        """
        if app_orchestration_config.prompt_template.simple_prompt_template is None:
            app_orchestration_config.prompt_template.simple_prompt_template = ''

        return app_orchestration_config

    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Handle tool response
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please tell user to check it."
            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
            else:
                result += f"tool response: {response.message}."

        return result

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            agent_tool=tool,
            agent_callback=self.agent_callback
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            message_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                message_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = 'string'

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
                parameter_type = 'string'
            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                parameter_type = 'boolean'
            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
                parameter_type = 'number'
            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
                for option in parameter.options:
                    enum.append(option.value)
                parameter_type = 'string'
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
        """
        Extract tool response binary
        """
        result = []

        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                result.append(ToolInvokeMessageBinary(
                    mimetype=response.meta.get('mime_type', 'octet/stream'),
                    url=response.message,
                    save_as=response.save_as,
                ))
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # check if there is a mime type in meta
                if response.meta and 'mime_type' in response.meta:
                    result.append(ToolInvokeMessageBinary(
                        mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
                        url=response.message,
                        save_as=response.save_as,
                    ))

        return result

    def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
        """
        Create message file
        :param messages: messages
        :return: message files, should save as variable
        """
        result = []

        for message in messages:
            file_type = 'bin'
            if 'image' in message.mimetype:
                file_type = 'image'
            elif 'video' in message.mimetype:
                file_type = 'video'
            elif 'audio' in message.mimetype:
                file_type = 'audio'
            elif 'text' in message.mimetype:
                file_type = 'text'
            elif 'pdf' in message.mimetype:
                file_type = 'pdf'
            elif 'zip' in message.mimetype:
                file_type = 'archive'
            # ...

            invoke_from = self.application_generate_entity.invoke_from

            message_file = MessageFile(
                message_id=self.message.id,
                type=file_type,
                transfer_method=FileTransferMethod.TOOL_FILE.value,
                belongs_to='assistant',
                url=message.url,
                upload_file_id=None,
                created_by_role=('account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
                created_by=self.user_id,
            )
            db.session.add(message_file)
            db.session.commit()
            db.session.refresh(message_file)

            result.append((
                message_file,
                message.save_as
            ))

        db.session.close()

        return result

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: list[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: str,
                           answer: str,
                           messages_ids: list[str],
                           llm_usage: Optional[LLMUsage] = None) -> MessageAgentThought:
        """
        Save agent thought
        """
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought.id
        ).first()

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(';') if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        db.session.commit()
        db.session.close()

    def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
        """
        Transform tool message into agent thought
        """
        result = []

        for message in messages:
            if message.type == ToolInvokeMessage.MessageType.TEXT:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                result.append(message)
            elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                # try to download image
                try:
                    file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
                                                              conversation_id=self.message.conversation_id,
                                                              file_url=message.message)

                    url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                except Exception as e:
                    logger.exception(e)
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.TEXT,
                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                        meta=message.meta.copy() if message.meta is not None else {},
                        save_as=message.save_as,
                    ))
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get mime type and save blob to storage
                mimetype = message.meta.get('mime_type', 'octet/stream')
                # if message is str, encode it to bytes
                if isinstance(message.message, str):
                    message.message = message.message.encode('utf-8')

                file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
                                                          conversation_id=self.message.conversation_id,
                                                          file_binary=message.message,
                                                          mimetype=mimetype)
                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'

                # check if file is image
                if 'image' in mimetype:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
                else:
                    result.append(ToolInvokeMessage(
                        type=ToolInvokeMessage.MessageType.LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=message.meta.copy() if message.meta is not None else {},
                    ))
            else:
                result.append(message)

        return result

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        db_variables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == self.message.conversation_id,
        ).first()

        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            result.append(prompt_messages[0])

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            result.append(UserPromptMessage(content=message.query))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
                            logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input))
                            tool_inputs = {agent_thought.tool: agent_thought.tool_input}

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=agent_thought.observation,
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result