assistant_fc_runner.py

import json
import logging
from typing import Union, Generator, Dict, Any, Tuple, List

from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, \
    SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
    ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
    ToolProviderCredentialValidationError
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)
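

# This runner drives an OpenAI-style function-calling agent loop: the model
# is invoked with the available tools, any tool calls it emits are executed,
# and the tool observations are appended to the prompt for the next round,
# until the model answers without calling a tool.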
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, model_instance: ModelInstance,
            conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_orchestration_config = self.app_orchestration_config

        prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: List[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool
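
        # cap the agent at five tool-use iterations; the extra "+ 1" step runs
        # with the tool list emptied (see the check at the top of the loop
        # below), forcing the model to produce a final plain-text answer.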
        iteration_step = 1
        max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1

        # continue to run until there is no tool call left
        function_call_state = True
        agent_thoughts: List[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''
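
        # usage is accumulated in place on the shared dict, so the running
        # totals survive across loop iterations.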
        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_orchestration_config.model_config.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )
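
            # consume the stream: collect tool calls and text deltas while
            # re-yielding every chunk to the caller unchanged.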
            tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            for chunk in chunks:
                # check if there is any tool call
                if self.check_tool_calls(chunk):
                    function_call_state = True
                    tool_calls.extend(self.extract_tool_calls(chunk))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # json.dumps raises TypeError/ValueError (never JSONDecodeError);
                        # fall back to ascii-escaped output to avoid encoding errors
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if chunk.delta.message and chunk.delta.message.content:
                    if isinstance(chunk.delta.message.content, list):
                        for content in chunk.delta.message.content:
                            response += content.data
                    else:
                        response += chunk.delta.message.content

                if chunk.delta.usage:
                    increase_usage(llm_usage, chunk.delta.usage)
                    current_llm_usage = chunk.delta.usage

                yield chunk

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'
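
            # each tool observation is fed back into the prompt as a
            # ToolPromptMessage (via organize_prompt_messages below) so the
            # next model invocation can see it.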
            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    logger.error(f"failed to find tool instance: {tool_call_name}")
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}"
                    }
                    tool_responses.append(tool_response)
                else:
                    # invoke tool
                    error_response = None
                    try:
                        # note: 'tool_paramters' is the keyword this codebase's
                        # tool interface expects, so the spelling is kept as-is
                        tool_invoke_message = tool_instance.invoke(
                            user_id=self.user_id,
                            tool_paramters=tool_call_args,
                        )
                        # transform tool invoke message to get an LLM-friendly message
                        tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
                        # extract binary data from tool invoke message
                        binary_files = self.extract_tool_response_binary(tool_invoke_message)
                        # create message files
                        message_files = self.create_message_files(binary_files)
                        # publish files
                        for message_file, save_as in message_files:
                            if save_as:
                                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
                            # publish message file
                            self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            # add message file ids
                            message_file_ids.append(message_file.id)
                    except ToolProviderCredentialValidationError:
                        error_response = "Please check your tool provider credentials"
                    except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError):
                        error_response = f"there is not a tool named {tool_call_name}"
                    except ToolParamterValidationError as e:
                        error_response = f"tool parameter validation error: {e}, please check your tool parameters"
                    except ToolInvokeError as e:
                        error_response = f"tool invoke error: {e}"
                    except Exception as e:
                        error_response = f"unknown error: {e}"

                    if error_response:
                        observation = error_response
                        logger.error(error_response)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": error_response
                        }
                        tool_responses.append(tool_response)
                    else:
                        observation = self._convert_tool_response_to_str(tool_invoke_message)
                        tool_response = {
                            "tool_call_id": tool_call_id,
                            "tool_call_name": tool_call_name,
                            "tool_response": observation
                        }
                        tool_responses.append(tool_response)

                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    observation=tool_response['tool_response'],
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))

            # update prompt tools
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer,
            ),
            usage=llm_usage['usage'],
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in the llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False
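
    # extraction below assumes the provider delivers each tool call's
    # arguments as complete JSON within a single chunk; partially streamed
    # argument fragments would make json.loads raise.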
    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def organize_prompt_messages(self, prompt_template: str,
                                 query: str = None,
                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                 prompt_messages: List[PromptMessage] = None
                                 ) -> List[PromptMessage]:
        """
        Organize prompt messages
        """
        if not prompt_messages:
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages