base_agent_runner.py

import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        model_instance: ModelInstance,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []

        self.query: Optional[str] = ""
        self._current_thoughts: list[PromptMessage] = []
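
    # NOTE: this base class only prepares shared state; the actual agent loop lives in
    # concrete subclasses (e.g. a function-calling or chain-of-thought style runner --
    # an assumption based on the helpers below). A typical subclass calls
    # _init_prompt_tools() once, drives the LLM through self.model_instance, and
    # persists each iteration with create_agent_thought() / save_agent_thought().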

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm if tool_entity.description else "",
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity
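
    # Illustrative example of the schema built above (hypothetical parameter names):
    # a tool exposing a required "query" string and an optional "language" select
    # would yield message_tool.parameters roughly like:
    #
    #   {
    #       "type": "object",
    #       "properties": {
    #           "query": {"type": "string", "description": "..."},
    #           "language": {"type": "string", "description": "...", "enum": ["en", "zh"]},
    #       },
    #       "required": ["query"],
    #   }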

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name if tool.identity else "unknown",
            description=tool.description.llm if tool.description else "",
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            if dataset_tool.identity is not None:
                tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools
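
    # Sketch of how a concrete runner might consume this pair (names other than the
    # methods defined in this class are assumptions):
    #
    #   tool_instances, prompt_messages_tools = self._init_prompt_tools()
    #   # ...send prompt_messages_tools to the model, read back a tool call...
    #   tool = tool_instances[tool_call_name]
    #   prompt_tool = self.update_prompt_message_tool(tool, prompt_tool)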

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought
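
    # Typical lifecycle (assumption, based on the fields above): create_agent_thought()
    # inserts a placeholder row when an iteration starts (position is 1-based via
    # self.agent_thought_count + 1), and save_agent_thought() below fills in the
    # thought, tool input/output and token usage once the iteration finishes.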

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str,
        tool_input: Union[str, dict],
        thought: str,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        queried_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not queried_thought:
            raise ValueError(f"Agent thought {agent_thought.id} not found")
        agent_thought = queried_thought

        if thought:
            agent_thought.thought = thought

        if tool_name:
            agent_thought.tool = tool_name

        if tool_input:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}
        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables
        """
        queried_variables = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )

        if not queried_variables:
            return

        db_variables = queried_variables
        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()
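
    # Usage note (assumption): callers keep mutating the in-memory ToolRuntimeVariablePool
    # while the agent runs and call update_db_variables() at the end of the run to write
    # the pool back onto the conversation's ToolConversationVariables row.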

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )

                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
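
    # Rough shape of the history rebuilt above for one past agent turn (illustrative):
    #
    #   UserPromptMessage(content=...),
    #   AssistantPromptMessage(content=thought, tool_calls=[ToolCall(...), ...]),
    #   ToolPromptMessage(content=observation, name=tool, tool_call_id=...),  # one per tool
    #
    # Tool names come from the semicolon-joined `tool` column of MessageAgentThought.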

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)

        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        file_objs = file_factory.build_from_message_files(
            message_files=files, tenant_id=self.tenant_id, config=file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)

        prompt_message_contents: list[PromptMessageContent] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            prompt_message_contents.append(
                file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=image_detail_config,
                )
            )

        return UserPromptMessage(content=prompt_message_contents)
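

# Minimal wiring sketch (everything outside this module, including the subclass name and
# its run() signature, is an assumption for illustration only):
#
#   runner = SomeConcreteAgentRunner(
#       tenant_id=tenant_id,
#       application_generate_entity=application_generate_entity,
#       conversation=conversation,
#       app_config=app_config,
#       model_config=model_config,
#       config=app_config.agent,
#       queue_manager=queue_manager,
#       message=message,
#       user_id=user_id,
#       memory=memory,
#       prompt_messages=prompt_messages,
#       variables_pool=variables_pool,
#       db_variables=db_variables,
#       model_instance=model_instance,
#   )
#   for chunk in runner.run(message=message, query=query):
#       ...  # stream chunks back through the queue manager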