# fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager
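
        # llm_usage is a one-element dict rather than a bare LLMUsage so the
        # nested increase_usage() closure below can mutate the shared
        # accumulator in place across loop iterations.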
        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids: list[str] = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None
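
            # Two consumption paths follow: stream mode yields chunks as they
            # arrive, while blocking mode handles one complete LLMResult.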
            if self.stream_tool_call and isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError:
                            # fall back to ensure_ascii=True to avoid encoding errors
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += str(chunk.delta.message.content)

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            elif not self.stream_tool_call and isinstance(chunks, LLMResult):
                result = chunks

                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError:
                        # fall back to ensure_ascii=True to avoid encoding errors
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += str(result.message.content)

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
            else:
                raise RuntimeError(f"invalid chunks type: {type(chunks)}")
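
            # Fold this iteration's output into a single assistant message:
            # either the accumulated tool calls or the plain text response.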
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )
                    # publish files
                    for message_file_id, save_as in message_files:
                        if save_as and self.variables_pool:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=str(tool_response["tool_response"]),
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )
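
            # A second thought record stores all tool outputs for this
            # iteration, keyed by tool name, as the agent's observation.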
            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name="",
                    tool_input="",
                    thought="",
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer="",
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        if self.variables_pool and self.db_variables_pool:
            self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(
        self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
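            # Assumes the provider delivers a tool call's arguments as complete
            # JSON within a single chunk; json.loads would raise on a partial
            # fragment.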
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
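
            # Low detail is presumably the default to keep per-image token
            # cost down when no explicit image config is set.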
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT models support both function calling and vision only on the
        first iteration, so image contents are stripped from user prompt messages
        on subsequent iterations.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # strip image contents from user messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
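

# Illustrative example (not part of the runner): given a UserPromptMessage whose
# content is [TextPromptMessageContent(data="describe this"), <image content>],
# _clear_user_prompt_image_messages() flattens it to the plain string
#     "describe this\n[image]"
# so that follow-up iterations send text-only prompts.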