assistant_base_runner.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import json
  2. import logging
  3. from datetime import datetime
  4. from mimetypes import guess_extension
  5. from typing import List, Optional, Tuple, Union, cast
  6. from core.app_runner.app_runner import AppRunner
  7. from core.application_queue_manager import ApplicationQueueManager
  8. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  9. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  10. from core.entities.application_entities import (
  11. AgentEntity,
  12. AgentToolEntity,
  13. ApplicationGenerateEntity,
  14. AppOrchestrationConfigEntity,
  15. InvokeFrom,
  16. ModelConfigEntity,
  17. )
  18. from core.file.message_file_parser import FileTransferMethod
  19. from core.memory.token_buffer_memory import TokenBufferMemory
  20. from core.model_manager import ModelInstance
  21. from core.model_runtime.entities.llm_entities import LLMUsage
  22. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  23. from core.model_runtime.entities.model_entities import ModelFeature
  24. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  25. from core.model_runtime.utils.encoders import jsonable_encoder
  26. from core.tools.entities.tool_entities import (
  27. ToolInvokeMessage,
  28. ToolInvokeMessageBinary,
  29. ToolParameter,
  30. ToolRuntimeVariablePool,
  31. )
  32. from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
  33. from core.tools.tool.tool import Tool
  34. from core.tools.tool_file_manager import ToolFileManager
  35. from core.tools.tool_manager import ToolManager
  36. from extensions.ext_database import db
  37. from models.model import Message, MessageAgentThought, MessageFile
  38. from models.tools import ToolConversationVariables
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)
  40. class BaseAssistantApplicationRunner(AppRunner):
  def __init__(self, tenant_id: str,
               application_generate_entity: ApplicationGenerateEntity,
               app_orchestration_config: AppOrchestrationConfigEntity,
               model_config: ModelConfigEntity,
               config: AgentEntity,
               queue_manager: ApplicationQueueManager,
               message: Message,
               user_id: str,
               memory: Optional[TokenBufferMemory] = None,
               prompt_messages: Optional[List[PromptMessage]] = None,
               variables_pool: Optional[ToolRuntimeVariablePool] = None,
               db_variables: Optional[ToolConversationVariables] = None,
               model_instance: Optional[ModelInstance] = None
               ) -> None:
      """
      Agent runner
      :param tenant_id: tenant id
      :param application_generate_entity: application generate entity (app_id and invoke_from are read from it)
      :param app_orchestration_config: app orchestration config
      :param model_config: model config
      :param config: agent config
      :param queue_manager: queue manager
      :param message: message
      :param user_id: user id
      :param memory: memory
      :param prompt_messages: history prompt messages, stored as-is
      :param variables_pool: tool runtime variables pool
      :param db_variables: tool conversation variables record
      :param model_instance: model instance, required despite the None default
      :raises ValueError: if model_instance is not provided
      """
      # The default of None only exists to keep the signature keyword-friendly;
      # the model schema lookup below cannot work without a real instance.
      if model_instance is None:
          raise ValueError("model_instance is required")

      self.tenant_id = tenant_id
      self.application_generate_entity = application_generate_entity
      self.app_orchestration_config = app_orchestration_config
      self.model_config = model_config
      self.config = config
      self.queue_manager = queue_manager
      self.message = message
      self.user_id = user_id
      self.memory = memory
      self.history_prompt_messages = prompt_messages
      self.variables_pool = variables_pool
      self.db_variables_pool = db_variables
      self.model_instance = model_instance

      # init callback
      self.agent_callback = DifyAgentCallbackHandler()

      # init dataset tools
      hit_callback = DatasetIndexToolCallbackHandler(
          queue_manager=queue_manager,
          app_id=self.application_generate_entity.app_id,
          message_id=message.id,
          user_id=user_id,
          invoke_from=self.application_generate_entity.invoke_from,
      )
      dataset_config = app_orchestration_config.dataset
      self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
          tenant_id=tenant_id,
          dataset_ids=dataset_config.dataset_ids if dataset_config else [],
          retrieve_config=dataset_config.retrieve_config if dataset_config else None,
          return_resource=app_orchestration_config.show_retrieve_source,
          invoke_from=application_generate_entity.invoke_from,
          hit_callback=hit_callback
      )

      # get how many agent thoughts have been created
      self.agent_thought_count = db.session.query(MessageAgentThought).filter(
          MessageAgentThought.message_id == self.message.id,
      ).count()

      # check if model supports stream tool call
      llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
      model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
      self.stream_tool_call = bool(
          model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or [])
      )
  def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
      """
      Normalise the prompt template of an app orchestration config.

      A simple prompt template of None is replaced by an empty string. The
      config is modified in place and the same object is returned.
      """
      prompt_template = app_orchestration_config.prompt_template
      if prompt_template.simple_prompt_template is None:
          prompt_template.simple_prompt_template = ''

      return app_orchestration_config
  def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
      """
      Flatten a list of tool invoke messages into one string.

      Each message contributes one fragment chosen by its type; the fragments
      are concatenated in order without any separator.
      """
      image_types = (
          ToolInvokeMessage.MessageType.IMAGE_LINK,
          ToolInvokeMessage.MessageType.IMAGE,
      )

      fragments = []
      for item in tool_response:
          kind = item.type
          if kind == ToolInvokeMessage.MessageType.TEXT:
              # plain text is passed through untouched
              fragments.append(item.message)
          elif kind == ToolInvokeMessage.MessageType.LINK:
              fragments.append(f"result link: {item.message}. please tell user to check it.")
          elif kind in image_types:
              fragments.append("image has been created and sent to user already, you should tell user to check it now.")
          else:
              fragments.append(f"tool response: {item.message}.")

      return ''.join(fragments)
  def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
      """
      convert tool to prompt message tool

      Parameters whose form is FORM are resolved from ``tool.tool_parameters``
      (falling back to the parameter default), converted to the declared type
      and stored in the tool runtime parameters. Parameters whose form is LLM
      are exposed in the JSON schema of the returned prompt message tool.

      :param tool: agent tool entity holding provider, tool name and form values
      :return: prompt message tool and the tool runtime it was built from
      :raises ValueError: if a parameter type is not supported, a required form
          parameter has no value, a select value is not among the options, or a
          form value cannot be converted to the declared type
      """
      def is_unset(value) -> bool:
          # 0 and False are legitimate configured values, so only None and
          # empty non-numeric values (e.g. '') count as "not configured".
          return not value and not isinstance(value, (bool, int, float))

      tool_entity = ToolManager.get_tool_runtime(
          provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
          tenant_id=self.application_generate_entity.tenant_id,
          agent_callback=self.agent_callback
      )
      tool_entity.load_variables(self.variables_pool)

      message_tool = PromptMessageTool(
          name=tool.tool_name,
          description=tool_entity.description.llm,
          parameters={
              "type": "object",
              "properties": {},
              "required": [],
          }
      )

      runtime_parameters = {}

      parameters = tool_entity.parameters or []
      user_parameters = tool_entity.get_runtime_parameters() or []

      # override parameters
      for parameter in user_parameters:
          # first declared parameter with the same name, if any
          existing = next((p for p in parameters if p.name == parameter.name), None)
          if existing is None:
              # add new parameter
              parameters.append(parameter)
              continue

          # override parameter
          existing.type = parameter.type
          existing.form = parameter.form
          existing.required = parameter.required
          existing.default = parameter.default
          existing.options = parameter.options
          existing.llm_description = parameter.llm_description

      for parameter in parameters:
          parameter_type = 'string'
          enum = []
          if parameter.type == ToolParameter.ToolParameterType.STRING:
              parameter_type = 'string'
          elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
              parameter_type = 'boolean'
          elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
              parameter_type = 'number'
          elif parameter.type == ToolParameter.ToolParameterType.SELECT:
              enum = [option.value for option in parameter.options]
              parameter_type = 'string'
          else:
              raise ValueError(f"parameter type {parameter.type} is not supported")

          if parameter.form == ToolParameter.ToolParameterForm.FORM:
              # get tool parameter from form
              tool_parameter_config = tool.tool_parameters.get(parameter.name)
              if is_unset(tool_parameter_config):
                  # get default value
                  tool_parameter_config = parameter.default
                  if is_unset(tool_parameter_config) and parameter.required:
                      raise ValueError(f"tool parameter {parameter.name} not found in tool config")

              if parameter.type == ToolParameter.ToolParameterType.SELECT:
                  # check if tool_parameter_config in options
                  options = enum
                  if tool_parameter_config not in options:
                      raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

              # convert tool parameter config to correct type
              # (STRING and SELECT values are stored as configured)
              try:
                  if parameter.type == ToolParameter.ToolParameterType.NUMBER:
                      # int / float values are kept; strings are parsed
                      if isinstance(tool_parameter_config, str):
                          if '.' in tool_parameter_config:
                              tool_parameter_config = float(tool_parameter_config)
                          else:
                              tool_parameter_config = int(tool_parameter_config)
                  elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                      if isinstance(tool_parameter_config, str):
                          # bool('false') is True, so textual negatives are mapped explicitly
                          tool_parameter_config = tool_parameter_config.lower() not in ('', 'false', 'no', 'n', '0')
                      else:
                          tool_parameter_config = bool(tool_parameter_config)
              except Exception as e:
                  raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") from e

              # save tool parameter to tool entity memory
              runtime_parameters[parameter.name] = tool_parameter_config

          elif parameter.form == ToolParameter.ToolParameterForm.LLM:
              message_tool.parameters['properties'][parameter.name] = {
                  "type": parameter_type,
                  "description": parameter.llm_description or '',
              }

              if len(enum) > 0:
                  message_tool.parameters['properties'][parameter.name]['enum'] = enum

              if parameter.required:
                  message_tool.parameters['required'].append(parameter.name)

      tool_entity.runtime.runtime_parameters.update(runtime_parameters)

      return message_tool, tool_entity
  def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
      """
      Build the prompt message tool describing a dataset retriever tool.

      Every runtime parameter of the tool is exposed as a string property;
      required parameters are listed once in the schema's ``required`` list.
      """
      prompt_tool = PromptMessageTool(
          name=tool.identity.name,
          description=tool.description.llm,
          parameters={
              "type": "object",
              "properties": {},
              "required": [],
          }
      )

      for runtime_parameter in tool.get_runtime_parameters():
          param_name = runtime_parameter.name
          prompt_tool.parameters['properties'][param_name] = {
              "type": 'string',
              "description": runtime_parameter.llm_description or '',
          }

          # skip optional parameters and names that are already registered
          if not runtime_parameter.required:
              continue
          if param_name in prompt_tool.parameters['required']:
              continue
          prompt_tool.parameters['required'].append(param_name)

      return prompt_tool
  def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
      """
      Refresh a prompt message tool from the tool's runtime parameters.

      Each runtime parameter whose form is LLM is written into the schema of
      ``prompt_tool`` (which is modified in place and returned). A parameter
      with an unsupported type raises ValueError regardless of its form.
      """
      # parameter types that map one-to-one onto a JSON schema type
      scalar_types = (
          (ToolParameter.ToolParameterType.STRING, 'string'),
          (ToolParameter.ToolParameterType.BOOLEAN, 'boolean'),
          (ToolParameter.ToolParameterType.NUMBER, 'number'),
      )

      # try to get tool runtime parameters
      for parameter in tool.get_runtime_parameters() or []:
          choices = []
          schema_type = next(
              (json_type for candidate, json_type in scalar_types if parameter.type == candidate),
              None,
          )
          if schema_type is None:
              if parameter.type == ToolParameter.ToolParameterType.SELECT:
                  # a select is a string restricted to its option values
                  choices = [option.value for option in parameter.options]
                  schema_type = 'string'
              else:
                  raise ValueError(f"parameter type {parameter.type} is not supported")

          if parameter.form != ToolParameter.ToolParameterForm.LLM:
              continue

          entry = {
              "type": schema_type,
              "description": parameter.llm_description or '',
          }
          if choices:
              entry['enum'] = choices
          prompt_tool.parameters['properties'][parameter.name] = entry

          if parameter.required and parameter.name not in prompt_tool.parameters['required']:
              prompt_tool.parameters['required'].append(parameter.name)

      return prompt_tool
  def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
      """
      Extract tool response binary

      IMAGE_LINK, IMAGE and BLOB messages always yield a binary entry, with
      mime type 'octet/stream' when the meta carries none. LINK messages only
      yield one when their meta contains a 'mime_type'. Other message types
      are ignored.

      :param tool_response: tool invoke messages
      :return: binary descriptors (mimetype, url, save_as)
      """
      always_binary = (
          ToolInvokeMessage.MessageType.IMAGE_LINK,
          ToolInvokeMessage.MessageType.IMAGE,
          ToolInvokeMessage.MessageType.BLOB,
      )

      result = []

      for response in tool_response:
          # meta may be None; treat it as empty instead of failing on .get
          meta = response.meta or {}

          if response.type in always_binary:
              mimetype = meta.get('mime_type', 'octet/stream')
          elif response.type == ToolInvokeMessage.MessageType.LINK and 'mime_type' in meta:
              # a plain link is only a file when a mime type is declared
              mimetype = meta['mime_type']
          else:
              continue

          result.append(ToolInvokeMessageBinary(
              mimetype=mimetype,
              url=response.message,
              save_as=response.save_as,
          ))

      return result
  def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
      """
      Create message file
      :param messages: binary tool outputs to attach to the current message
      :return: list of (message file, save_as value of the source message)
      """
      # checked in order against the mime type; the first marker found wins
      mimetype_markers = (
          ('image', 'image'),
          ('video', 'video'),
          ('audio', 'audio'),
          ('text', 'text'),
          ('pdf', 'pdf'),
          ('zip', 'archive'),
      )

      created = []

      for binary in messages:
          file_type = next(
              (label for marker, label in mimetype_markers if marker in binary.mimetype),
              'bin',
          )

          invoke_from = self.application_generate_entity.invoke_from
          if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]:
              creator_role = 'account'
          else:
              creator_role = 'end_user'

          record = MessageFile(
              message_id=self.message.id,
              type=file_type,
              transfer_method=FileTransferMethod.TOOL_FILE.value,
              belongs_to='assistant',
              url=binary.url,
              upload_file_id=None,
              created_by_role=creator_role,
              created_by=self.user_id,
          )
          db.session.add(record)
          created.append((record, binary.save_as))

      db.session.commit()

      return created
  def create_agent_thought(self, message_id: str, message: str,
                           tool_name: str, tool_input: str, messages_ids: List[str]
                           ) -> MessageAgentThought:
      """
      Persist a new, mostly empty agent thought and return it.

      The thought is positioned right after the thoughts counted so far, and
      the runner's thought counter is advanced once the row is committed.
      """
      next_position = self.agent_thought_count + 1
      serialized_files = json.dumps(messages_ids) if messages_ids else ''

      fields = dict(
          message_id=message_id,
          message_chain_id=None,
          thought='',
          tool=tool_name,
          tool_labels_str='{}',
          tool_input=tool_input,
          message=message,
          message_token=0,
          message_unit_price=0,
          message_price_unit=0,
          message_files=serialized_files,
          answer='',
          observation='',
          answer_token=0,
          answer_unit_price=0,
          answer_price_unit=0,
          tokens=0,
          total_price=0,
          position=next_position,
          currency='USD',
          latency=0,
          created_by_role='account',
          created_by=self.user_id,
      )
      thought = MessageAgentThought(**fields)

      db.session.add(thought)
      db.session.commit()

      self.agent_thought_count = next_position

      return thought
  def save_agent_thought(self,
                         agent_thought: MessageAgentThought,
                         tool_name: str,
                         tool_input: Union[str, dict],
                         thought: str,
                         observation: str,
                         answer: str,
                         messages_ids: List[str],
                         llm_usage: Optional[LLMUsage] = None) -> MessageAgentThought:
      """
      Save agent thought

      Only arguments that are not None overwrite the stored value, so callers
      can update a thought incrementally. Tool labels are completed for every
      tool named in the thought and the change is committed.

      :param agent_thought: agent thought to update
      :param tool_name: tool name(s), several names separated by ';'
      :param tool_input: tool input, a dict is stored as JSON
      :param thought: thought text
      :param observation: observation text
      :param answer: answer text
      :param messages_ids: message file ids, stored as JSON when non-empty
      :param llm_usage: llm usage, copied into the token and price fields
      :return: the updated agent thought
      """
      if thought is not None:
          agent_thought.thought = thought

      if tool_name is not None:
          agent_thought.tool = tool_name

      if tool_input is not None:
          if isinstance(tool_input, dict):
              try:
                  # keep non-ASCII text readable in the stored JSON
                  tool_input = json.dumps(tool_input, ensure_ascii=False)
              except Exception:
                  tool_input = json.dumps(tool_input)

          agent_thought.tool_input = tool_input

      if observation is not None:
          agent_thought.observation = observation

      if answer is not None:
          agent_thought.answer = answer

      if messages_ids is not None and len(messages_ids) > 0:
          agent_thought.message_files = json.dumps(messages_ids)

      if llm_usage:
          agent_thought.message_token = llm_usage.prompt_tokens
          agent_thought.message_price_unit = llm_usage.prompt_price_unit
          agent_thought.message_unit_price = llm_usage.prompt_unit_price
          agent_thought.answer_token = llm_usage.completion_tokens
          agent_thought.answer_price_unit = llm_usage.completion_price_unit
          agent_thought.answer_unit_price = llm_usage.completion_unit_price
          agent_thought.tokens = llm_usage.total_tokens
          agent_thought.total_price = llm_usage.total_price

      # check if tool labels is not empty
      labels = agent_thought.tool_labels or {}
      tools = agent_thought.tool.split(';') if agent_thought.tool else []
      for tool in tools:
          if not tool or tool in labels:
              continue

          tool_label = ToolManager.get_tool_label(tool)
          if tool_label:
              labels[tool] = tool_label.to_dict()
          else:
              # no label found: fall back to the raw tool name for both languages
              labels[tool] = {'en_US': tool, 'zh_Hans': tool}

      agent_thought.tool_labels_str = json.dumps(labels)

      db.session.commit()

      # the signature promises the thought back; returning it keeps callers
      # that ignore the result working unchanged
      return agent_thought
  def get_history_prompt_messages(self) -> List[PromptMessage]:
      """
      Return the history prompt messages, loading them on first access.

      Messages that are already set on the runner (including an empty list)
      are returned as-is; only a value of None triggers the database query.
      """
      cached = self.history_prompt_messages
      if cached is not None:
          return cached

      # NOTE(review): PromptMessage is imported from the model runtime message
      # entities, not from the models package - presumably it is not a mapped
      # ORM class, in which case this query cannot work; confirm.
      query = db.session.query(PromptMessage).filter(
          PromptMessage.message_id == self.message.id,
      )
      self.history_prompt_messages = query.order_by(PromptMessage.position.asc()).all()

      return self.history_prompt_messages
  def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
      """
      Transform tool message into agent thought

      IMAGE messages are downloaded into a tool file and replaced by an
      IMAGE_LINK message; when the download fails a TEXT message describing
      the failure is emitted instead. BLOB messages are stored as a tool file
      and replaced by an IMAGE_LINK (image mime types) or LINK message. All
      other messages are passed through unchanged.

      :param messages: tool invoke messages
      :return: transformed messages, in the same order
      """
      result = []

      for message in messages:
          if message.type == ToolInvokeMessage.MessageType.IMAGE:
              # try to download image
              try:
                  file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
                                                            conversation_id=self.message.conversation_id,
                                                            file_url=message.message)

                  url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

                  result.append(ToolInvokeMessage(
                      type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                      message=url,
                      save_as=message.save_as,
                      meta=message.meta.copy() if message.meta is not None else {},
                  ))
              except Exception as e:
                  logger.exception(e)
                  result.append(ToolInvokeMessage(
                      type=ToolInvokeMessage.MessageType.TEXT,
                      message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                      meta=message.meta.copy() if message.meta is not None else {},
                      save_as=message.save_as,
                  ))
          elif message.type == ToolInvokeMessage.MessageType.BLOB:
              # meta may be None (it is guarded the same way for images above)
              meta = message.meta if message.meta is not None else {}

              # get mime type and save blob to storage
              mimetype = meta.get('mime_type', 'octet/stream')

              # if message is str, encode it to bytes; work on a local copy so
              # the caller's message object is left untouched
              blob = message.message
              if isinstance(blob, str):
                  blob = blob.encode('utf-8')

              file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
                                                        conversation_id=self.message.conversation_id,
                                                        file_binary=blob,
                                                        mimetype=mimetype)

              url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'

              # images become image links, everything else a plain link
              if 'image' in mimetype:
                  link_type = ToolInvokeMessage.MessageType.IMAGE_LINK
              else:
                  link_type = ToolInvokeMessage.MessageType.LINK

              result.append(ToolInvokeMessage(
                  type=link_type,
                  message=url,
                  save_as=message.save_as,
                  meta=meta.copy(),
              ))
          else:
              # TEXT, LINK and any other type need no transformation
              result.append(message)

      return result
  def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
      """
      Write the tool variable pool into its database record.

      The record's update time is refreshed, the pool is stored as a JSON
      string and the session is committed.
      """
      now = datetime.utcnow()
      db_variables.updated_at = now

      encoded_pool = jsonable_encoder(tool_variables.pool)
      db_variables.variables_str = json.dumps(encoded_pool)

      db.session.commit()
  521. db.session.commit()