assistant_cot_runner.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. import json
  2. import logging
  3. import re
  4. from typing import Literal, Union, Generator, Dict, List
  5. from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
  6. from core.application_queue_manager import PublishFrom
  7. from core.model_runtime.utils.encoders import jsonable_encoder
  8. from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, \
  9. UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
  10. from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
  11. from core.model_manager import ModelInstance
  12. from core.tools.errors import ToolInvokeError, ToolNotFoundError, \
  13. ToolNotSupportedError, ToolProviderNotFoundError, ToolParamterValidationError, \
  14. ToolProviderCredentialValidationError
  15. from core.features.assistant_base_runner import BaseAssistantApplicationRunner
  16. from models.model import Conversation, Message
  17. class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
  18. def run(self, model_instance: ModelInstance,
  19. conversation: Conversation,
  20. message: Message,
  21. query: str,
  22. ) -> Union[Generator, LLMResult]:
  23. """
  24. Run Cot agent application
  25. """
  26. app_orchestration_config = self.app_orchestration_config
  27. self._repacket_app_orchestration_config(app_orchestration_config)
  28. agent_scratchpad: List[AgentScratchpadUnit] = []
  29. # check model mode
  30. if self.app_orchestration_config.model_config.mode == "completion":
  31. # TODO: stop words
  32. if 'Observation' not in app_orchestration_config.model_config.stop:
  33. app_orchestration_config.model_config.stop.append('Observation')
  34. iteration_step = 1
  35. max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
  36. prompt_messages = self.history_prompt_messages
  37. # convert tools into ModelRuntime Tool format
  38. prompt_messages_tools: List[PromptMessageTool] = []
  39. tool_instances = {}
  40. for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
  41. try:
  42. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  43. except Exception:
  44. # api tool may be deleted
  45. continue
  46. # save tool entity
  47. tool_instances[tool.tool_name] = tool_entity
  48. # save prompt tool
  49. prompt_messages_tools.append(prompt_tool)
  50. # convert dataset tools into ModelRuntime Tool format
  51. for dataset_tool in self.dataset_tools:
  52. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  53. # save prompt tool
  54. prompt_messages_tools.append(prompt_tool)
  55. # save tool entity
  56. tool_instances[dataset_tool.identity.name] = dataset_tool
  57. function_call_state = True
  58. llm_usage = {
  59. 'usage': None
  60. }
  61. final_answer = ''
  62. def increse_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
  63. if not final_llm_usage_dict['usage']:
  64. final_llm_usage_dict['usage'] = usage
  65. else:
  66. llm_usage = final_llm_usage_dict['usage']
  67. llm_usage.prompt_tokens += usage.prompt_tokens
  68. llm_usage.completion_tokens += usage.completion_tokens
  69. llm_usage.prompt_price += usage.prompt_price
  70. llm_usage.completion_price += usage.completion_price
  71. while function_call_state and iteration_step <= max_iteration_steps:
  72. # continue to run until there is not any tool call
  73. function_call_state = False
  74. if iteration_step == max_iteration_steps:
  75. # the last iteration, remove all tools
  76. prompt_messages_tools = []
  77. message_file_ids = []
  78. agent_thought = self.create_agent_thought(
  79. message_id=message.id,
  80. message='',
  81. tool_name='',
  82. tool_input='',
  83. messages_ids=message_file_ids
  84. )
  85. if iteration_step > 1:
  86. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  87. # update prompt messages
  88. prompt_messages = self._originze_cot_prompt_messages(
  89. mode=app_orchestration_config.model_config.mode,
  90. prompt_messages=prompt_messages,
  91. tools=prompt_messages_tools,
  92. agent_scratchpad=agent_scratchpad,
  93. agent_prompt_message=app_orchestration_config.agent.prompt,
  94. instruction=app_orchestration_config.prompt_template.simple_prompt_template,
  95. input=query
  96. )
  97. # recale llm max tokens
  98. self.recale_llm_max_tokens(self.model_config, prompt_messages)
  99. # invoke model
  100. llm_result: LLMResult = model_instance.invoke_llm(
  101. prompt_messages=prompt_messages,
  102. model_parameters=app_orchestration_config.model_config.parameters,
  103. tools=[],
  104. stop=app_orchestration_config.model_config.stop,
  105. stream=False,
  106. user=self.user_id,
  107. callbacks=[],
  108. )
  109. # check llm result
  110. if not llm_result:
  111. raise ValueError("failed to invoke llm")
  112. # get scratchpad
  113. scratchpad = self._extract_response_scratchpad(llm_result.message.content)
  114. agent_scratchpad.append(scratchpad)
  115. # get llm usage
  116. if llm_result.usage:
  117. increse_usage(llm_usage, llm_result.usage)
  118. # publish agent thought if it's first iteration
  119. if iteration_step == 1:
  120. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  121. self.save_agent_thought(agent_thought=agent_thought,
  122. tool_name=scratchpad.action.action_name if scratchpad.action else '',
  123. tool_input=scratchpad.action.action_input if scratchpad.action else '',
  124. thought=scratchpad.thought,
  125. observation='',
  126. answer=llm_result.message.content,
  127. messages_ids=[],
  128. llm_usage=llm_result.usage)
  129. if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
  130. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  131. # publish agent thought if it's not empty and there is a action
  132. if scratchpad.thought and scratchpad.action:
  133. # check if final answer
  134. if not scratchpad.action.action_name.lower() == "final answer":
  135. yield LLMResultChunk(
  136. model=model_instance.model,
  137. prompt_messages=prompt_messages,
  138. delta=LLMResultChunkDelta(
  139. index=0,
  140. message=AssistantPromptMessage(
  141. content=scratchpad.thought
  142. ),
  143. usage=llm_result.usage,
  144. ),
  145. system_fingerprint=''
  146. )
  147. if not scratchpad.action:
  148. # failed to extract action, return final answer directly
  149. final_answer = scratchpad.agent_response or ''
  150. else:
  151. if scratchpad.action.action_name.lower() == "final answer":
  152. # action is final answer, return final answer directly
  153. try:
  154. final_answer = scratchpad.action.action_input if \
  155. isinstance(scratchpad.action.action_input, str) else \
  156. json.dumps(scratchpad.action.action_input)
  157. except json.JSONDecodeError:
  158. final_answer = f'{scratchpad.action.action_input}'
  159. else:
  160. function_call_state = True
  161. # action is tool call, invoke tool
  162. tool_call_name = scratchpad.action.action_name
  163. tool_call_args = scratchpad.action.action_input
  164. tool_instance = tool_instances.get(tool_call_name)
  165. if not tool_instance:
  166. answer = f"there is not a tool named {tool_call_name}"
  167. self.save_agent_thought(agent_thought=agent_thought,
  168. tool_name='',
  169. tool_input='',
  170. thought=None,
  171. observation=answer,
  172. answer=answer,
  173. messages_ids=[])
  174. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  175. else:
  176. # invoke tool
  177. error_response = None
  178. try:
  179. tool_response = tool_instance.invoke(
  180. user_id=self.user_id,
  181. tool_paramters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
  182. )
  183. # transform tool response to llm friendly response
  184. tool_response = self.transform_tool_invoke_messages(tool_response)
  185. # extract binary data from tool invoke message
  186. binary_files = self.extract_tool_response_binary(tool_response)
  187. # create message file
  188. message_files = self.create_message_files(binary_files)
  189. # publish files
  190. for message_file, save_as in message_files:
  191. if save_as:
  192. self.variables_pool.set_file(tool_name=tool_call_name,
  193. value=message_file.id,
  194. name=save_as)
  195. self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
  196. message_file_ids = [message_file.id for message_file, _ in message_files]
  197. except ToolProviderCredentialValidationError as e:
  198. error_response = f"Plese check your tool provider credentials"
  199. except (
  200. ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
  201. ) as e:
  202. error_response = f"there is not a tool named {tool_call_name}"
  203. except (
  204. ToolParamterValidationError
  205. ) as e:
  206. error_response = f"tool paramters validation error: {e}, please check your tool paramters"
  207. except ToolInvokeError as e:
  208. error_response = f"tool invoke error: {e}"
  209. except Exception as e:
  210. error_response = f"unknown error: {e}"
  211. if error_response:
  212. observation = error_response
  213. else:
  214. observation = self._convert_tool_response_to_str(tool_response)
  215. # save scratchpad
  216. scratchpad.observation = observation
  217. scratchpad.agent_response = llm_result.message.content
  218. # save agent thought
  219. self.save_agent_thought(
  220. agent_thought=agent_thought,
  221. tool_name=tool_call_name,
  222. tool_input=tool_call_args,
  223. thought=None,
  224. observation=observation,
  225. answer=llm_result.message.content,
  226. messages_ids=message_file_ids,
  227. )
  228. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  229. # update prompt tool message
  230. for prompt_tool in prompt_messages_tools:
  231. self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
  232. iteration_step += 1
  233. yield LLMResultChunk(
  234. model=model_instance.model,
  235. prompt_messages=prompt_messages,
  236. delta=LLMResultChunkDelta(
  237. index=0,
  238. message=AssistantPromptMessage(
  239. content=final_answer
  240. ),
  241. usage=llm_usage['usage']
  242. ),
  243. system_fingerprint=''
  244. )
  245. # save agent thought
  246. self.save_agent_thought(
  247. agent_thought=agent_thought,
  248. tool_name='',
  249. tool_input='',
  250. thought=final_answer,
  251. observation='',
  252. answer=final_answer,
  253. messages_ids=[]
  254. )
  255. self.update_db_variables(self.variables_pool, self.db_variables_pool)
  256. # publish end event
  257. self.queue_manager.publish_message_end(LLMResult(
  258. model=model_instance.model,
  259. prompt_messages=prompt_messages,
  260. message=AssistantPromptMessage(
  261. content=final_answer
  262. ),
  263. usage=llm_usage['usage'],
  264. system_fingerprint=''
  265. ), PublishFrom.APPLICATION_MANAGER)
  266. def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
  267. """
  268. extract response from llm response
  269. """
  270. def extra_quotes() -> AgentScratchpadUnit:
  271. agent_response = content
  272. # try to extract all quotes
  273. pattern = re.compile(r'```(.*?)```', re.DOTALL)
  274. quotes = pattern.findall(content)
  275. # try to extract action from end to start
  276. for i in range(len(quotes) - 1, 0, -1):
  277. """
  278. 1. use json load to parse action
  279. 2. use plain text `Action: xxx` to parse action
  280. """
  281. try:
  282. action = json.loads(quotes[i].replace('```', ''))
  283. action_name = action.get("action")
  284. action_input = action.get("action_input")
  285. agent_thought = agent_response.replace(quotes[i], '')
  286. if action_name and action_input:
  287. return AgentScratchpadUnit(
  288. agent_response=content,
  289. thought=agent_thought,
  290. action_str=quotes[i],
  291. action=AgentScratchpadUnit.Action(
  292. action_name=action_name,
  293. action_input=action_input,
  294. )
  295. )
  296. except:
  297. # try to parse action from plain text
  298. action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
  299. action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
  300. # delete action from agent response
  301. agent_thought = agent_response.replace(quotes[i], '')
  302. # remove extra quotes
  303. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  304. # remove Action: xxx from agent thought
  305. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  306. if action_name and action_input:
  307. return AgentScratchpadUnit(
  308. agent_response=content,
  309. thought=agent_thought,
  310. action_str=quotes[i],
  311. action=AgentScratchpadUnit.Action(
  312. action_name=action_name[0],
  313. action_input=action_input[0],
  314. )
  315. )
  316. def extra_json():
  317. agent_response = content
  318. # try to extract all json
  319. structures, pair_match_stack = [], []
  320. started_at, end_at = 0, 0
  321. for i in range(len(content)):
  322. if content[i] == '{':
  323. pair_match_stack.append(i)
  324. if len(pair_match_stack) == 1:
  325. started_at = i
  326. elif content[i] == '}':
  327. begin = pair_match_stack.pop()
  328. if not pair_match_stack:
  329. end_at = i + 1
  330. structures.append((content[begin:i+1], (started_at, end_at)))
  331. # handle the last character
  332. if pair_match_stack:
  333. end_at = len(content)
  334. structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
  335. for i in range(len(structures), 0, -1):
  336. try:
  337. json_content, (started_at, end_at) = structures[i - 1]
  338. action = json.loads(json_content)
  339. action_name = action.get("action")
  340. action_input = action.get("action_input")
  341. # delete json content from agent response
  342. agent_thought = agent_response[:started_at] + agent_response[end_at:]
  343. # remove extra quotes like ```(json)*\n\n```
  344. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  345. # remove Action: xxx from agent thought
  346. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  347. if action_name and action_input:
  348. return AgentScratchpadUnit(
  349. agent_response=content,
  350. thought=agent_thought,
  351. action_str=json_content,
  352. action=AgentScratchpadUnit.Action(
  353. action_name=action_name,
  354. action_input=action_input,
  355. )
  356. )
  357. except:
  358. pass
  359. agent_scratchpad = extra_quotes()
  360. if agent_scratchpad:
  361. return agent_scratchpad
  362. agent_scratchpad = extra_json()
  363. if agent_scratchpad:
  364. return agent_scratchpad
  365. return AgentScratchpadUnit(
  366. agent_response=content,
  367. thought=content,
  368. action_str='',
  369. action=None
  370. )
  371. def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  372. agent_prompt_message: AgentPromptEntity,
  373. ):
  374. """
  375. check chain of thought prompt messages, a standard prompt message is like:
  376. Respond to the human as helpfully and accurately as possible.
  377. {{instruction}}
  378. You have access to the following tools:
  379. {{tools}}
  380. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  381. Valid action values: "Final Answer" or {{tool_names}}
  382. Provide only ONE action per $JSON_BLOB, as shown:
  383. ```
  384. {
  385. "action": $TOOL_NAME,
  386. "action_input": $ACTION_INPUT
  387. }
  388. ```
  389. """
  390. # parse agent prompt message
  391. first_prompt = agent_prompt_message.first_prompt
  392. next_iteration = agent_prompt_message.next_iteration
  393. if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
  394. raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
  395. # check instruction, tools, and tool_names slots
  396. if not first_prompt.find("{{instruction}}") >= 0:
  397. raise ValueError("{{instruction}} is required in first_prompt")
  398. if not first_prompt.find("{{tools}}") >= 0:
  399. raise ValueError("{{tools}} is required in first_prompt")
  400. if not first_prompt.find("{{tool_names}}") >= 0:
  401. raise ValueError("{{tool_names}} is required in first_prompt")
  402. if mode == "completion":
  403. if not first_prompt.find("{{query}}") >= 0:
  404. raise ValueError("{{query}} is required in first_prompt")
  405. if not first_prompt.find("{{agent_scratchpad}}") >= 0:
  406. raise ValueError("{{agent_scratchpad}} is required in first_prompt")
  407. if mode == "completion":
  408. if not next_iteration.find("{{observation}}") >= 0:
  409. raise ValueError("{{observation}} is required in next_iteration")
  410. def _convert_strachpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
  411. """
  412. convert agent scratchpad list to str
  413. """
  414. next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
  415. result = ''
  416. for scratchpad in agent_scratchpad:
  417. result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
  418. return result
  419. def _originze_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  420. prompt_messages: List[PromptMessage],
  421. tools: List[PromptMessageTool],
  422. agent_scratchpad: List[AgentScratchpadUnit],
  423. agent_prompt_message: AgentPromptEntity,
  424. instruction: str,
  425. input: str,
  426. ) -> List[PromptMessage]:
  427. """
  428. originze chain of thought prompt messages, a standard prompt message is like:
  429. Respond to the human as helpfully and accurately as possible.
  430. {{instruction}}
  431. You have access to the following tools:
  432. {{tools}}
  433. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  434. Valid action values: "Final Answer" or {{tool_names}}
  435. Provide only ONE action per $JSON_BLOB, as shown:
  436. ```
  437. {{{{
  438. "action": $TOOL_NAME,
  439. "action_input": $ACTION_INPUT
  440. }}}}
  441. ```
  442. """
  443. self._check_cot_prompt_messages(mode, agent_prompt_message)
  444. # parse agent prompt message
  445. first_prompt = agent_prompt_message.first_prompt
  446. # parse tools
  447. tools_str = self._jsonify_tool_prompt_messages(tools)
  448. # parse tools name
  449. tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
  450. # get system message
  451. system_message = first_prompt.replace("{{instruction}}", instruction) \
  452. .replace("{{tools}}", tools_str) \
  453. .replace("{{tool_names}}", tool_names)
  454. # originze prompt messages
  455. if mode == "chat":
  456. # override system message
  457. overrided = False
  458. prompt_messages = prompt_messages.copy()
  459. for prompt_message in prompt_messages:
  460. if isinstance(prompt_message, SystemPromptMessage):
  461. prompt_message.content = system_message
  462. overrided = True
  463. break
  464. if not overrided:
  465. prompt_messages.insert(0, SystemPromptMessage(
  466. content=system_message,
  467. ))
  468. # add assistant message
  469. if len(agent_scratchpad) > 0:
  470. prompt_messages.append(AssistantPromptMessage(
  471. content=(agent_scratchpad[-1].thought or '')
  472. ))
  473. # add user message
  474. if len(agent_scratchpad) > 0:
  475. prompt_messages.append(UserPromptMessage(
  476. content=(agent_scratchpad[-1].observation or ''),
  477. ))
  478. return prompt_messages
  479. elif mode == "completion":
  480. # parse agent scratchpad
  481. agent_scratchpad_str = self._convert_strachpad_list_to_str(agent_scratchpad)
  482. # parse prompt messages
  483. return [UserPromptMessage(
  484. content=first_prompt.replace("{{instruction}}", instruction)
  485. .replace("{{tools}}", tools_str)
  486. .replace("{{tool_names}}", tool_names)
  487. .replace("{{query}}", input)
  488. .replace("{{agent_scratchpad}}", agent_scratchpad_str),
  489. )]
  490. else:
  491. raise ValueError(f"mode {mode} is not supported")
  492. def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
  493. """
  494. jsonify tool prompt messages
  495. """
  496. tools = jsonable_encoder(tools)
  497. try:
  498. return json.dumps(tools, ensure_ascii=False)
  499. except json.JSONDecodeError:
  500. return json.dumps(tools)