# fc_agent_runner.py
  1. import json
  2. import logging
  3. from collections.abc import Generator
  4. from copy import deepcopy
  5. from typing import Any, Union
  6. from core.agent.base_agent_runner import BaseAgentRunner
  7. from core.app.apps.base_app_queue_manager import PublishFrom
  8. from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
  9. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
  10. from core.model_runtime.entities.message_entities import (
  11. AssistantPromptMessage,
  12. PromptMessage,
  13. PromptMessageContentType,
  14. PromptMessageTool,
  15. SystemPromptMessage,
  16. TextPromptMessageContent,
  17. ToolPromptMessage,
  18. UserPromptMessage,
  19. )
  20. from core.tools.entities.tool_entities import ToolInvokeMeta
  21. from core.tools.tool_engine import ToolEngine
  22. from models.model import Message
  23. logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application.

        Repeatedly invokes the LLM with the available tools. Whenever the model
        emits tool calls, the tools are executed and their responses appended to
        the prompt, then the model is invoked again — until the model answers
        without tool calls or the iteration cap is hit. Streams LLMResultChunk
        objects to the caller and publishes agent-thought / message-file / end
        events through the queue manager.

        :param message: the Message record this run belongs to
        :param query: the user's query text
        :return: generator of LLMResultChunk for streaming output
        """
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
        prompt_messages = self._organize_user_query(query, prompt_messages)

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in app_config.agent.tools if app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        iteration_step = 1
        # cap tool iterations at 5; the extra "+1" step runs with the tool list
        # emptied (see below) so the model is forced to produce a final answer
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            # accumulate per-iteration token/price usage into the shared dict;
            # the dict wrapper lets the closure mutate state seen by the outer scope
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalc llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model; a generator when streaming, a single LLMResult otherwise
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except json.JSONDecodeError as e:
                            # ensure ascii to avoid encoding error
                            # NOTE(review): json.dumps raises TypeError/ValueError, not
                            # JSONDecodeError, so this fallback looks unreachable — confirm
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except json.JSONDecodeError as e:
                        # ensure ascii to avoid encoding error
                        # NOTE(review): json.dumps raises TypeError/ValueError, not
                        # JSONDecodeError, so this fallback looks unreachable — confirm
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

                # wrap the blocking result in a single chunk so callers always
                # consume a uniform stream
                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )

            # record the assistant turn: either the tool calls made or the text answer
            assistant_message = AssistantPromptMessage(
                content='',
                tool_calls=[]
            )
            if tool_calls:
                assistant_message.tool_calls=[
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            prompt_messages.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                    )
                    # publish files
                    for message_file, save_as in message_files:
                        if save_as:
                            # stash the file into the variables pool under the tool's name
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(QueueMessageFileEvent(
                            message_file_id=message_file.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        # add message file ids
                        message_file_ids.append(message_file.id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict()
                    }

                tool_responses.append(tool_response)
                # append the tool observation to the prompt for the next iteration
                prompt_messages = self._organize_assistant_message(
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response['tool_call_name']: tool_response['meta']
                        for tool_response in tool_responses
                    },
                    observation={
                        tool_response['tool_call_name']: tool_response['tool_response']
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

            # flatten image/file contents so later (tool-calling) iterations
            # don't resend vision payloads
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)
  295. def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
  296. """
  297. Check if there is any tool call in llm result chunk
  298. """
  299. if llm_result_chunk.delta.message.tool_calls:
  300. return True
  301. return False
  302. def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
  303. """
  304. Check if there is any blocking tool call in llm result
  305. """
  306. if llm_result.message.tool_calls:
  307. return True
  308. return False
  309. def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
  310. """
  311. Extract tool calls from llm result chunk
  312. Returns:
  313. List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
  314. """
  315. tool_calls = []
  316. for prompt_message in llm_result_chunk.delta.message.tool_calls:
  317. tool_calls.append((
  318. prompt_message.id,
  319. prompt_message.function.name,
  320. json.loads(prompt_message.function.arguments),
  321. ))
  322. return tool_calls
  323. def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
  324. """
  325. Extract blocking tool calls from llm result
  326. Returns:
  327. List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
  328. """
  329. tool_calls = []
  330. for prompt_message in llm_result.message.tool_calls:
  331. tool_calls.append((
  332. prompt_message.id,
  333. prompt_message.function.name,
  334. json.loads(prompt_message.function.arguments),
  335. ))
  336. return tool_calls
  337. def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
  338. """
  339. Initialize system message
  340. """
  341. if not prompt_messages and prompt_template:
  342. return [
  343. SystemPromptMessage(content=prompt_template),
  344. ]
  345. if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
  346. prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
  347. return prompt_messages
  348. def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
  349. """
  350. Organize user query
  351. """
  352. if self.files:
  353. prompt_message_contents = [TextPromptMessageContent(data=query)]
  354. for file_obj in self.files:
  355. prompt_message_contents.append(file_obj.prompt_message_content)
  356. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  357. else:
  358. prompt_messages.append(UserPromptMessage(content=query))
  359. return prompt_messages
  360. def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
  361. prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
  362. """
  363. Organize assistant message
  364. """
  365. prompt_messages = deepcopy(prompt_messages)
  366. if tool_response is not None:
  367. prompt_messages.append(
  368. ToolPromptMessage(
  369. content=tool_response,
  370. tool_call_id=tool_call_id,
  371. name=tool_call_name,
  372. )
  373. )
  374. return prompt_messages
  375. def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
  376. """
  377. As for now, gpt supports both fc and vision at the first iteration.
  378. We need to remove the image messages from the prompt messages at the first iteration.
  379. """
  380. prompt_messages = deepcopy(prompt_messages)
  381. for prompt_message in prompt_messages:
  382. if isinstance(prompt_message, UserPromptMessage):
  383. if isinstance(prompt_message.content, list):
  384. prompt_message.content = '\n'.join([
  385. content.data if content.type == PromptMessageContentType.TEXT else
  386. '[image]' if content.type == PromptMessageContentType.IMAGE else
  387. '[file]'
  388. for content in prompt_message.content
  389. ])
  390. return prompt_messages