# assistant_fc_runner.py

import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.errors import (
    ToolInvokeError,
    ToolNotFoundError,
    ToolNotSupportedError,
    ToolParameterValidationError,
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)

class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
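    """
    Function-call based agent runner.

    Drives the native function/tool-calling loop: invoke the LLM with the
    available tools, execute any tool calls it returns, feed the results back
    as ToolPromptMessages, and repeat until the model answers without calling
    a tool or the iteration budget runs out.
    """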
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_orchestration_config = self.app_orchestration_config

        prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )
        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # the api tool may have been deleted, skip it
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        iteration_step = 1
        max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1
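        # max_iteration is capped at 5; the extra "+ 1" buys one final pass in
        # which all tools are removed (see the last-iteration branch below),
        # forcing the model to produce a plain answer instead of another tool call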

        # continue to run until there are no more tool calls
        function_call_state = True
        agent_thoughts: List[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
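
        # increase_usage mutates the LLMUsage object held in the dict in place,
        # so per-iteration usage accumulates into one running total across the
        # whole agent loop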

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # on the last iteration, remove all tools to force a final answer
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_orchestration_config.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )
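
            # invoke_llm returns a chunk generator when stream=True and a
            # single LLMResult when stream=False; the two branches below
            # normalize both shapes into the same accumulators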
            tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None
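
            # streaming path: tool-call deltas and content arrive chunk by
            # chunk, and each chunk is re-yielded to the caller as it comes in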
            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except (TypeError, ValueError):
                            # json.dumps raises TypeError/ValueError on
                            # unserializable input; coerce values to strings
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, default=str)

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks

                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # json.dumps raises TypeError/ValueError on
                        # unserializable input; coerce values to strings
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, default=str)

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
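
            # the blocking result is re-wrapped above as a single
            # LLMResultChunk so callers consume one uniform chunk stream
            # regardless of streaming mode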
            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))
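
            # the assistant message above, carrying the tool_calls, must land
            # in the history before the matching ToolPromptMessages appended
            # later, mirroring the OpenAI-style function-calling message order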
            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}"
                    }
                    tool_responses.append(tool_response)
                else:
                    # invoke tool
                    error_response = None
                    try:
                        tool_invoke_message = tool_instance.invoke(
                            user_id=self.user_id,
                            tool_parameters=tool_call_args,
                        )
                        # transform tool invoke message into an LLM-friendly message
                        tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                        # extract binary data from tool invoke message
                        binary_files = self.extract_tool_response_binary(tool_invoke_message)
                        # create message files
                        message_files = self.create_message_files(binary_files)
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                            # publish message file
                            self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)
                    except ToolProviderCredentialValidationError:
                        error_response = "Please check your tool provider credentials"
                    except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError):
                        error_response = f"there is no tool named {tool_call_name}"
                    except ToolParameterValidationError as e:
                        error_response = f"tool parameter validation error: {e}, please check your tool parameters"
                    except ToolInvokeError as e:
                        error_response = f"tool invoke error: {e}"
                    except Exception as e:
                        error_response = f"unknown error: {e}"
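
                    # error text is deliberately surfaced as the tool's
                    # observation below, so the model sees the failure and can
                    # retry or explain instead of the run aborting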
                    if error_response:
                        observation = error_response
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": error_response
                        }
                        tool_responses.append(tool_response)
                    else:
                        observation = self._convert_tool_response_to_str(tool_invoke_message)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": observation
                        }
                        tool_responses.append(tool_response)

                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    observation=tool_response['tool_response'],  # observation from the last tool call in this round
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer,
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls
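
    # NOTE: both extractors parse function.arguments with json.loads, i.e.
    # they assume the model runtime delivers each tool call's arguments as one
    # complete JSON string rather than as partial streaming fragments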

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: str = None,
                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                 prompt_messages: list[PromptMessage] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages: start a fresh system + user pair when no
        history is given, otherwise append the tool response (if any) to the
        existing messages.
        """
        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages
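
# A typical prompt_messages sequence after one tool round looks like:
#   SystemPromptMessage(prompt_template)
#   UserPromptMessage(query)
#   AssistantPromptMessage(tool_calls=[...])
#   ToolPromptMessage(tool_response, tool_call_id=..., name=...)
# followed by another assistant turn that either calls more tools or answers.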