assistant_fc_runner.py

import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
                                                          SystemPromptMessage, ToolPromptMessage, UserPromptMessage)
from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError,
                               ToolParameterValidationError, ToolProviderCredentialValidationError,
                               ToolProviderNotFoundError)
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)
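

# This runner implements an OpenAI-style function-calling agent loop: each
# iteration sends the conversation plus tool schemas to the model, executes
# any tool calls it returns, feeds the observations back into the prompt,
# and stops once the model answers without tool calls or the iteration
# budget runs out.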
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_orchestration_config = self.app_orchestration_config

        prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        iteration_step = 1
        max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        agent_thoughts: List[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''
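
        # Accumulate usage across iterations: the first LLMUsage is stored
        # as-is, subsequent ones are added onto it field by field.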
        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_orchestration_config.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )
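
            # `chunks` is a streaming generator when stream_tool_call is enabled,
            # otherwise a complete LLMResult; both branches below fill `tool_calls`
            # with (tool_call_id, tool_call_name, tool_call_args) tuples.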
            tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except (TypeError, ValueError):
                            # retry with the default ensure_ascii=True to avoid encoding errors
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # retry with the default ensure_ascii=True to avoid encoding errors
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
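
            # Echo the assistant's tool calls back into the prompt history so each
            # later ToolPromptMessage can reference a matching tool_call_id.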
            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))
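
            # Tool execution is fault-tolerant: a missing tool or a failed invocation
            # becomes an observation string for the model instead of aborting the run.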
            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}"
                    }
                    tool_responses.append(tool_response)
                else:
                    # invoke tool
                    error_response = None
                    try:
                        tool_invoke_message = tool_instance.invoke(
                            user_id=self.user_id,
                            tool_parameters=tool_call_args,
                        )
                        # transform tool invoke message to get LLM friendly message
                        tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                        # extract binary data from tool invoke message
                        binary_files = self.extract_tool_response_binary(tool_invoke_message)
                        # create message file
                        message_files = self.create_message_files(binary_files)
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
                            # publish message file
                            self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)
                    except ToolProviderCredentialValidationError:
                        error_response = "Please check your tool provider credentials"
                    except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError):
                        error_response = f"there is no tool named {tool_call_name}"
                    except ToolParameterValidationError as e:
                        error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                    except ToolInvokeError as e:
                        error_response = f"tool invoke error: {e}"
                    except Exception as e:
                        error_response = f"unknown error: {e}"

                    if error_response:
                        observation = error_response
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": error_response
                        }
                        tool_responses.append(tool_response)
                    else:
                        observation = self._convert_tool_response_to_str(tool_invoke_message)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": observation
                        }
                        tool_responses.append(tool_response)

                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if tool_responses:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    observation=tool_response['tool_response'],
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1
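
        # After the loop: persist variable pool changes and emit the end event
        # with the accumulated answer and usage.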
        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer,
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))
        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))
        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: str = None,
                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                 prompt_messages: list[PromptMessage] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )
        return prompt_messages
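

# A minimal usage sketch (hypothetical wiring; in practice the application
# manager constructs the runner with its queue, model, and tool dependencies):
#
#     runner = AssistantFunctionCallApplicationRunner(...)  # dependencies elided
#     for chunk in runner.run(conversation, message, query="What's the weather?"):
#         if chunk.delta.message and chunk.delta.message.content:
#             print(chunk.delta.message.content, end='')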