assistant_fc_runner.py

import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.errors import (
    ToolInvokeError,
    ToolNotFoundError,
    ToolNotSupportedError,
    ToolParameterValidationError,
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)


class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_orchestration_config = self.app_orchestration_config

        prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool
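
        # The iteration budget below clamps max_iteration to 5 and adds one
        # extra pass; on that final pass the tool list is emptied, so the model
        # must produce a plain answer instead of another tool call.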
        iteration_step = 1
        max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        agent_thoughts: list[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''
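
        # Accumulates token usage across iterations by mutating the shared
        # llm_usage dict in place; the stored LLMUsage object itself is
        # updated on subsequent calls.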
        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalc llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
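            # (recalc_llm_max_tokens is defined on the base runner; presumably
            # it trims the max_tokens parameter so prompt plus completion fit
            # the model's context window.)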

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_orchestration_config.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None
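
            # Streaming mode yields chunks as they arrive; blocking mode gets a
            # single LLMResult and re-wraps it as one LLMResultChunk below, so
            # callers always consume a uniform stream.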
            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except UnicodeEncodeError:
                            # fall back to ASCII-escaped output to avoid encoding errors
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except UnicodeEncodeError:
                        # fall back to ASCII-escaped output to avoid encoding errors
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
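
            # Record the assistant turn that requested these tool calls so the
            # next model invocation sees them in the conversation history.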
            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))
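
            # Dispatch each requested tool call; the observation (or an error
            # string) is appended to the prompt as a ToolPromptMessage via
            # organize_prompt_messages below.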
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}"
                    }
                    tool_responses.append(tool_response)
                else:
                    # invoke tool
                    error_response = None
                    try:
                        tool_invoke_message = tool_instance.invoke(
                            user_id=self.user_id,
                            tool_parameters=tool_call_args,
                        )
                        # transform tool invoke message into an LLM-friendly message
                        tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                        # extract binary data from tool invoke message
                        binary_files = self.extract_tool_response_binary(tool_invoke_message)
                        # create message files
                        message_files = self.create_message_files(binary_files)
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                            # publish message file
                            self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)
                    except ToolProviderCredentialValidationError:
                        error_response = "Please check your tool provider credentials"
                    except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError):
                        error_response = f"there is no tool named {tool_call_name}"
                    except ToolParameterValidationError as e:
                        error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                    except ToolInvokeError as e:
                        error_response = f"tool invoke error: {e}"
                    except Exception as e:
                        error_response = f"unknown error: {e}"

                    if error_response:
                        observation = error_response
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": error_response
                        }
                        tool_responses.append(tool_response)
                    else:
                        observation = self._convert_tool_response_to_str(tool_invoke_message)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": observation
                        }
                        tool_responses.append(tool_response)

                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if tool_responses:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    observation=tool_response['tool_response'],
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt tools
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer,
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in the llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any tool call in the blocking llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from blocking llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: Optional[str] = None,
                                 tool_call_id: Optional[str] = None,
                                 tool_call_name: Optional[str] = None,
                                 tool_response: Optional[str] = None,
                                 prompt_messages: Optional[list[PromptMessage]] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages: start a fresh system + user pair on the first
        call, or append the latest tool response on subsequent calls.
        """
        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages