# base_agent_runner.py

import json
import logging
import uuid
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)

class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []

        self.query: Optional[str] = ""
        self._current_thoughts: list[PromptMessage] = []
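
    # Note: uploaded files are only forwarded to the model when its schema advertises the
    # VISION feature, and streaming tool calls are used only when STREAM_TOOL_CALL is
    # advertised (see the feature checks at the end of __init__ above).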

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert tool_entity.entity.description

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity
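
    # Rough shape of the parameters schema built above, for a hypothetical tool with a single
    # LLM-form parameter named "query" (names and values are illustrative, not from the source):
    #
    #     {
    #         "type": "object",
    #         "properties": {
    #             "query": {"type": "string", "description": "Keywords to search for"},
    #         },
    #         "required": ["query"],
    #     }
    #
    # An "enum" key is added to a property only for SELECT parameters that define options;
    # file-typed parameters are never exposed to the model.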

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        assert tool.entity.description

        prompt_tool = PromptMessageTool(
            name=tool.entity.identity.name,
            description=tool.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.entity.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools
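
    # The returned pair is (tool_instances, prompt_messages_tools): a mapping from tool name to
    # its runtime Tool instance, plus the function-calling definitions handed to the model.
    # Tools whose runtime cannot be built (e.g. a deleted API tool) are silently skipped.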

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str | None,
        tool_input: Union[str, dict, None],
        thought: str | None,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str | None,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        updated_agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not updated_agent_thought:
            raise ValueError("agent thought not found")
        agent_thought = updated_agent_thought

        if thought:
            agent_thought.thought = thought

        if tool_name:
            agent_thought.tool = tool_name

        if tool_input:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            updated_agent_thought.tool_input = tool_input

        if observation:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            updated_agent_thought.observation = observation

        if answer:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            updated_agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            updated_agent_thought.message_token = llm_usage.prompt_tokens
            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
            updated_agent_thought.answer_token = llm_usage.completion_tokens
            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
            updated_agent_thought.tokens = llm_usage.total_tokens
            updated_agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = updated_agent_thought.tool_labels or {}
        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        updated_agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            updated_agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()
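
    # tool_labels_str ends up as a JSON object keyed by tool name, roughly like this
    # (the tool name and labels below are illustrative only):
    #     {"google_search": {"en_US": "Google Search", "zh_Hans": "谷歌搜索"}}
    # with a fallback of {"en_US": <tool name>, "zh_Hans": <tool name>} when no label is registered.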

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
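
    # For a prior turn that used tools, the reconstructed history roughly reads:
    #     UserPromptMessage(query)
    #     AssistantPromptMessage(thought, tool_calls=[...])
    #     ToolPromptMessage(observation, tool_call_id=...)   # one per tool
    # Fresh UUIDs are generated for the tool call ids when rebuilding the history.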

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)
        if message.app_model_config:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        else:
            file_extra_config = None

        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        file_objs = file_factory.build_from_message_files(
            message_files=files, tenant_id=self.tenant_id, config=file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)

        prompt_message_contents: list[PromptMessageContent] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            prompt_message_contents.append(
                file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=image_detail_config,
                )
            )
        return UserPromptMessage(content=prompt_message_contents)
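
    # When a message carries files and a file-upload config, the user prompt is returned as a
    # list of prompt message contents (the query text plus one content item per file, built via
    # file_manager.to_prompt_message_content); otherwise it falls back to the plain query string.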