# assistant_cot_runner.py
import json
import re
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import (
    ToolInvokeError,
    ToolNotFoundError,
    ToolNotSupportedError,
    ToolParameterValidationError,
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from models.model import Conversation, Message


  25. class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
  26. def run(self, conversation: Conversation,
  27. message: Message,
  28. query: str,
  29. inputs: Dict[str, str],
  30. ) -> Union[Generator, LLMResult]:
  31. """
  32. Run Cot agent application
  33. """
  34. app_orchestration_config = self.app_orchestration_config
  35. self._repack_app_orchestration_config(app_orchestration_config)
  36. agent_scratchpad: List[AgentScratchpadUnit] = []
  37. # check model mode
  38. if self.app_orchestration_config.model_config.mode == "completion":
  39. # TODO: stop words
  40. if 'Observation' not in app_orchestration_config.model_config.stop:
  41. app_orchestration_config.model_config.stop.append('Observation')
  42. # override inputs
  43. inputs = inputs or {}
  44. instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
  45. instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
  46. iteration_step = 1
  47. max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
  48. prompt_messages = self.history_prompt_messages
  49. # convert tools into ModelRuntime Tool format
  50. prompt_messages_tools: List[PromptMessageTool] = []
  51. tool_instances = {}
  52. for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
  53. try:
  54. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  55. except Exception:
  56. # api tool may be deleted
  57. continue
  58. # save tool entity
  59. tool_instances[tool.tool_name] = tool_entity
  60. # save prompt tool
  61. prompt_messages_tools.append(prompt_tool)
  62. # convert dataset tools into ModelRuntime Tool format
  63. for dataset_tool in self.dataset_tools:
  64. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  65. # save prompt tool
  66. prompt_messages_tools.append(prompt_tool)
  67. # save tool entity
  68. tool_instances[dataset_tool.identity.name] = dataset_tool
  69. function_call_state = True
  70. llm_usage = {
  71. 'usage': None
  72. }
  73. final_answer = ''
  74. def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
  75. if not final_llm_usage_dict['usage']:
  76. final_llm_usage_dict['usage'] = usage
  77. else:
  78. llm_usage = final_llm_usage_dict['usage']
  79. llm_usage.prompt_tokens += usage.prompt_tokens
  80. llm_usage.completion_tokens += usage.completion_tokens
  81. llm_usage.prompt_price += usage.prompt_price
  82. llm_usage.completion_price += usage.completion_price
  83. model_instance = self.model_instance
  84. while function_call_state and iteration_step <= max_iteration_steps:
  85. # continue to run until there is not any tool call
  86. function_call_state = False
  87. if iteration_step == max_iteration_steps:
  88. # the last iteration, remove all tools
  89. prompt_messages_tools = []
  90. message_file_ids = []
  91. agent_thought = self.create_agent_thought(
  92. message_id=message.id,
  93. message='',
  94. tool_name='',
  95. tool_input='',
  96. messages_ids=message_file_ids
  97. )
  98. if iteration_step > 1:
  99. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  100. # update prompt messages
  101. prompt_messages = self._organize_cot_prompt_messages(
  102. mode=app_orchestration_config.model_config.mode,
  103. prompt_messages=prompt_messages,
  104. tools=prompt_messages_tools,
  105. agent_scratchpad=agent_scratchpad,
  106. agent_prompt_message=app_orchestration_config.agent.prompt,
  107. instruction=instruction,
  108. input=query
  109. )
  110. # recale llm max tokens
  111. self.recale_llm_max_tokens(self.model_config, prompt_messages)
  112. # invoke model
  113. llm_result: LLMResult = model_instance.invoke_llm(
  114. prompt_messages=prompt_messages,
  115. model_parameters=app_orchestration_config.model_config.parameters,
  116. tools=[],
  117. stop=app_orchestration_config.model_config.stop,
  118. stream=False,
  119. user=self.user_id,
  120. callbacks=[],
  121. )
  122. # check llm result
  123. if not llm_result:
  124. raise ValueError("failed to invoke llm")
  125. # get scratchpad
  126. scratchpad = self._extract_response_scratchpad(llm_result.message.content)
  127. agent_scratchpad.append(scratchpad)
  128. # get llm usage
  129. if llm_result.usage:
  130. increase_usage(llm_usage, llm_result.usage)
  131. # publish agent thought if it's first iteration
  132. if iteration_step == 1:
  133. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  134. self.save_agent_thought(agent_thought=agent_thought,
  135. tool_name=scratchpad.action.action_name if scratchpad.action else '',
  136. tool_input=scratchpad.action.action_input if scratchpad.action else '',
  137. thought=scratchpad.thought,
  138. observation='',
  139. answer=llm_result.message.content,
  140. messages_ids=[],
  141. llm_usage=llm_result.usage)
  142. if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
  143. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  144. # publish agent thought if it's not empty and there is a action
  145. if scratchpad.thought and scratchpad.action:
  146. # check if final answer
  147. if not scratchpad.action.action_name.lower() == "final answer":
  148. yield LLMResultChunk(
  149. model=model_instance.model,
  150. prompt_messages=prompt_messages,
  151. delta=LLMResultChunkDelta(
  152. index=0,
  153. message=AssistantPromptMessage(
  154. content=scratchpad.thought
  155. ),
  156. usage=llm_result.usage,
  157. ),
  158. system_fingerprint=''
  159. )
  160. if not scratchpad.action:
  161. # failed to extract action, return final answer directly
  162. final_answer = scratchpad.agent_response or ''
  163. else:
  164. if scratchpad.action.action_name.lower() == "final answer":
  165. # action is final answer, return final answer directly
  166. try:
  167. final_answer = scratchpad.action.action_input if \
  168. isinstance(scratchpad.action.action_input, str) else \
  169. json.dumps(scratchpad.action.action_input)
  170. except json.JSONDecodeError:
  171. final_answer = f'{scratchpad.action.action_input}'
  172. else:
  173. function_call_state = True
  174. # action is tool call, invoke tool
  175. tool_call_name = scratchpad.action.action_name
  176. tool_call_args = scratchpad.action.action_input
  177. tool_instance = tool_instances.get(tool_call_name)
  178. if not tool_instance:
  179. answer = f"there is not a tool named {tool_call_name}"
  180. self.save_agent_thought(agent_thought=agent_thought,
  181. tool_name='',
  182. tool_input='',
  183. thought=None,
  184. observation=answer,
  185. answer=answer,
  186. messages_ids=[])
  187. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  188. else:
  189. # invoke tool
  190. error_response = None
  191. try:
  192. tool_response = tool_instance.invoke(
  193. user_id=self.user_id,
  194. tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
  195. )
  196. # transform tool response to llm friendly response
  197. tool_response = self.transform_tool_invoke_messages(tool_response)
  198. # extract binary data from tool invoke message
  199. binary_files = self.extract_tool_response_binary(tool_response)
  200. # create message file
  201. message_files = self.create_message_files(binary_files)
  202. # publish files
  203. for message_file, save_as in message_files:
  204. if save_as:
  205. self.variables_pool.set_file(tool_name=tool_call_name,
  206. value=message_file.id,
  207. name=save_as)
  208. self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
  209. message_file_ids = [message_file.id for message_file, _ in message_files]
  210. except ToolProviderCredentialValidationError as e:
  211. error_response = "Please check your tool provider credentials"
  212. except (
  213. ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
  214. ) as e:
  215. error_response = f"there is not a tool named {tool_call_name}"
  216. except (
  217. ToolParameterValidationError
  218. ) as e:
  219. error_response = f"tool parameters validation error: {e}, please check your tool parameters"
  220. except ToolInvokeError as e:
  221. error_response = f"tool invoke error: {e}"
  222. except Exception as e:
  223. error_response = f"unknown error: {e}"
  224. if error_response:
  225. observation = error_response
  226. else:
  227. observation = self._convert_tool_response_to_str(tool_response)
  228. # save scratchpad
  229. scratchpad.observation = observation
  230. scratchpad.agent_response = llm_result.message.content
  231. # save agent thought
  232. self.save_agent_thought(
  233. agent_thought=agent_thought,
  234. tool_name=tool_call_name,
  235. tool_input=tool_call_args,
  236. thought=None,
  237. observation=observation,
  238. answer=llm_result.message.content,
  239. messages_ids=message_file_ids,
  240. )
  241. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  242. # update prompt tool message
  243. for prompt_tool in prompt_messages_tools:
  244. self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
  245. iteration_step += 1
  246. yield LLMResultChunk(
  247. model=model_instance.model,
  248. prompt_messages=prompt_messages,
  249. delta=LLMResultChunkDelta(
  250. index=0,
  251. message=AssistantPromptMessage(
  252. content=final_answer
  253. ),
  254. usage=llm_usage['usage']
  255. ),
  256. system_fingerprint=''
  257. )
  258. # save agent thought
  259. self.save_agent_thought(
  260. agent_thought=agent_thought,
  261. tool_name='',
  262. tool_input='',
  263. thought=final_answer,
  264. observation='',
  265. answer=final_answer,
  266. messages_ids=[]
  267. )
  268. self.update_db_variables(self.variables_pool, self.db_variables_pool)
  269. # publish end event
  270. self.queue_manager.publish_message_end(LLMResult(
  271. model=model_instance.model,
  272. prompt_messages=prompt_messages,
  273. message=AssistantPromptMessage(
  274. content=final_answer
  275. ),
  276. usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
  277. system_fingerprint=''
  278. ), PublishFrom.APPLICATION_MANAGER)
  279. def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
  280. """
  281. fill in inputs from external data tools
  282. """
  283. for key, value in inputs.items():
  284. try:
  285. instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
  286. except Exception as e:
  287. continue
  288. return instruction
  289. def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
  290. """
  291. extract response from llm response
  292. """
  293. def extra_quotes() -> AgentScratchpadUnit:
  294. agent_response = content
  295. # try to extract all quotes
  296. pattern = re.compile(r'```(.*?)```', re.DOTALL)
  297. quotes = pattern.findall(content)
  298. # try to extract action from end to start
  299. for i in range(len(quotes) - 1, 0, -1):
  300. """
  301. 1. use json load to parse action
  302. 2. use plain text `Action: xxx` to parse action
  303. """
  304. try:
  305. action = json.loads(quotes[i].replace('```', ''))
  306. action_name = action.get("action")
  307. action_input = action.get("action_input")
  308. agent_thought = agent_response.replace(quotes[i], '')
  309. if action_name and action_input:
  310. return AgentScratchpadUnit(
  311. agent_response=content,
  312. thought=agent_thought,
  313. action_str=quotes[i],
  314. action=AgentScratchpadUnit.Action(
  315. action_name=action_name,
  316. action_input=action_input,
  317. )
  318. )
  319. except:
  320. # try to parse action from plain text
  321. action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
  322. action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
  323. # delete action from agent response
  324. agent_thought = agent_response.replace(quotes[i], '')
  325. # remove extra quotes
  326. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  327. # remove Action: xxx from agent thought
  328. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  329. if action_name and action_input:
  330. return AgentScratchpadUnit(
  331. agent_response=content,
  332. thought=agent_thought,
  333. action_str=quotes[i],
  334. action=AgentScratchpadUnit.Action(
  335. action_name=action_name[0],
  336. action_input=action_input[0],
  337. )
  338. )
  339. def extra_json():
  340. agent_response = content
  341. # try to extract all json
  342. structures, pair_match_stack = [], []
  343. started_at, end_at = 0, 0
  344. for i in range(len(content)):
  345. if content[i] == '{':
  346. pair_match_stack.append(i)
  347. if len(pair_match_stack) == 1:
  348. started_at = i
  349. elif content[i] == '}':
  350. begin = pair_match_stack.pop()
  351. if not pair_match_stack:
  352. end_at = i + 1
  353. structures.append((content[begin:i+1], (started_at, end_at)))
  354. # handle the last character
  355. if pair_match_stack:
  356. end_at = len(content)
  357. structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
  358. for i in range(len(structures), 0, -1):
  359. try:
  360. json_content, (started_at, end_at) = structures[i - 1]
  361. action = json.loads(json_content)
  362. action_name = action.get("action")
  363. action_input = action.get("action_input")
  364. # delete json content from agent response
  365. agent_thought = agent_response[:started_at] + agent_response[end_at:]
  366. # remove extra quotes like ```(json)*\n\n```
  367. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  368. # remove Action: xxx from agent thought
  369. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  370. if action_name and action_input is not None:
  371. return AgentScratchpadUnit(
  372. agent_response=content,
  373. thought=agent_thought,
  374. action_str=json_content,
  375. action=AgentScratchpadUnit.Action(
  376. action_name=action_name,
  377. action_input=action_input,
  378. )
  379. )
  380. except:
  381. pass
  382. agent_scratchpad = extra_quotes()
  383. if agent_scratchpad:
  384. return agent_scratchpad
  385. agent_scratchpad = extra_json()
  386. if agent_scratchpad:
  387. return agent_scratchpad
  388. return AgentScratchpadUnit(
  389. agent_response=content,
  390. thought=content,
  391. action_str='',
  392. action=None
  393. )
  394. def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  395. agent_prompt_message: AgentPromptEntity,
  396. ):
  397. """
  398. check chain of thought prompt messages, a standard prompt message is like:
  399. Respond to the human as helpfully and accurately as possible.
  400. {{instruction}}
  401. You have access to the following tools:
  402. {{tools}}
  403. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  404. Valid action values: "Final Answer" or {{tool_names}}
  405. Provide only ONE action per $JSON_BLOB, as shown:
  406. ```
  407. {
  408. "action": $TOOL_NAME,
  409. "action_input": $ACTION_INPUT
  410. }
  411. ```
  412. """
  413. # parse agent prompt message
  414. first_prompt = agent_prompt_message.first_prompt
  415. next_iteration = agent_prompt_message.next_iteration
  416. if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
  417. raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
  418. # check instruction, tools, and tool_names slots
  419. if not first_prompt.find("{{instruction}}") >= 0:
  420. raise ValueError("{{instruction}} is required in first_prompt")
  421. if not first_prompt.find("{{tools}}") >= 0:
  422. raise ValueError("{{tools}} is required in first_prompt")
  423. if not first_prompt.find("{{tool_names}}") >= 0:
  424. raise ValueError("{{tool_names}} is required in first_prompt")
  425. if mode == "completion":
  426. if not first_prompt.find("{{query}}") >= 0:
  427. raise ValueError("{{query}} is required in first_prompt")
  428. if not first_prompt.find("{{agent_scratchpad}}") >= 0:
  429. raise ValueError("{{agent_scratchpad}} is required in first_prompt")
  430. if mode == "completion":
  431. if not next_iteration.find("{{observation}}") >= 0:
  432. raise ValueError("{{observation}} is required in next_iteration")
  433. def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
  434. """
  435. convert agent scratchpad list to str
  436. """
  437. next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
  438. result = ''
  439. for scratchpad in agent_scratchpad:
  440. result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
  441. return result
  442. def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  443. prompt_messages: List[PromptMessage],
  444. tools: List[PromptMessageTool],
  445. agent_scratchpad: List[AgentScratchpadUnit],
  446. agent_prompt_message: AgentPromptEntity,
  447. instruction: str,
  448. input: str,
  449. ) -> List[PromptMessage]:
  450. """
  451. organize chain of thought prompt messages, a standard prompt message is like:
  452. Respond to the human as helpfully and accurately as possible.
  453. {{instruction}}
  454. You have access to the following tools:
  455. {{tools}}
  456. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  457. Valid action values: "Final Answer" or {{tool_names}}
  458. Provide only ONE action per $JSON_BLOB, as shown:
  459. ```
  460. {{{{
  461. "action": $TOOL_NAME,
  462. "action_input": $ACTION_INPUT
  463. }}}}
  464. ```
  465. """
  466. self._check_cot_prompt_messages(mode, agent_prompt_message)
  467. # parse agent prompt message
  468. first_prompt = agent_prompt_message.first_prompt
  469. # parse tools
  470. tools_str = self._jsonify_tool_prompt_messages(tools)
  471. # parse tools name
  472. tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
  473. # get system message
  474. system_message = first_prompt.replace("{{instruction}}", instruction) \
  475. .replace("{{tools}}", tools_str) \
  476. .replace("{{tool_names}}", tool_names)
  477. # organize prompt messages
  478. if mode == "chat":
  479. # override system message
  480. overrided = False
  481. prompt_messages = prompt_messages.copy()
  482. for prompt_message in prompt_messages:
  483. if isinstance(prompt_message, SystemPromptMessage):
  484. prompt_message.content = system_message
  485. overrided = True
  486. break
  487. if not overrided:
  488. prompt_messages.insert(0, SystemPromptMessage(
  489. content=system_message,
  490. ))
  491. # add assistant message
  492. if len(agent_scratchpad) > 0:
  493. prompt_messages.append(AssistantPromptMessage(
  494. content=(agent_scratchpad[-1].thought or '')
  495. ))
  496. # add user message
  497. if len(agent_scratchpad) > 0:
  498. prompt_messages.append(UserPromptMessage(
  499. content=(agent_scratchpad[-1].observation or ''),
  500. ))
  501. return prompt_messages
  502. elif mode == "completion":
  503. # parse agent scratchpad
  504. agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
  505. # parse prompt messages
  506. return [UserPromptMessage(
  507. content=first_prompt.replace("{{instruction}}", instruction)
  508. .replace("{{tools}}", tools_str)
  509. .replace("{{tool_names}}", tool_names)
  510. .replace("{{query}}", input)
  511. .replace("{{agent_scratchpad}}", agent_scratchpad_str),
  512. )]
  513. else:
  514. raise ValueError(f"mode {mode} is not supported")
  515. def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
  516. """
  517. jsonify tool prompt messages
  518. """
  519. tools = jsonable_encoder(tools)
  520. try:
  521. return json.dumps(tools, ensure_ascii=False)
  522. except json.JSONDecodeError:
  523. return json.dumps(tools)