assistant_cot_runner.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. import json
  2. import re
  3. from collections.abc import Generator
  4. from typing import Literal, Union
  5. from core.application_queue_manager import PublishFrom
  6. from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
  7. from core.features.assistant_base_runner import BaseAssistantApplicationRunner
  8. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
  9. from core.model_runtime.entities.message_entities import (
  10. AssistantPromptMessage,
  11. PromptMessage,
  12. PromptMessageTool,
  13. SystemPromptMessage,
  14. ToolPromptMessage,
  15. UserPromptMessage,
  16. )
  17. from core.model_runtime.utils.encoders import jsonable_encoder
  18. from core.tools.errors import (
  19. ToolInvokeError,
  20. ToolNotFoundError,
  21. ToolNotSupportedError,
  22. ToolParameterValidationError,
  23. ToolProviderCredentialValidationError,
  24. ToolProviderNotFoundError,
  25. )
  26. from models.model import Conversation, Message
  27. class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            inputs: dict[str, str],
            ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application.

        Drives a ReAct-style loop: organize the CoT prompt, stream the LLM
        response, parse an action out of the stream, invoke the matching tool,
        and feed the observation back — until the model produces a final
        answer, fails to emit an action, or the iteration budget runs out.

        :param conversation: current conversation record (not referenced in this body)
        :param message: current message; its id anchors every agent thought
        :param query: user query for this turn
        :param inputs: values substituted into the instruction template
        :return: this method is a generator yielding LLMResultChunk objects
        """
        app_orchestration_config = self.app_orchestration_config
        self._repack_app_orchestration_config(app_orchestration_config)

        # rebuild scratchpad (thought/action/observation units) from history
        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            # completion models need 'Observation' as a stop word so generation
            # halts before the model hallucinates a tool result
            if 'Observation' not in app_orchestration_config.model_config.stop:
                app_orchestration_config.model_config.stop.append('Observation')

        # override inputs
        inputs = inputs or {}
        instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
        instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        # iterations are capped at 5 regardless of config; +1 reserves a final
        # tool-less pass (see the max_iteration_steps check in the loop)
        max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1

        prompt_messages = self.history_prompt_messages

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        function_call_state = True
        # wrapped in a dict so increase_usage can mutate it in place
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            # accumulate token/price usage across iterations
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )
            # first iteration publishes later, once streaming has begun
            if iteration_step > 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            prompt_messages = self._organize_cot_prompt_messages(
                mode=app_orchestration_config.model_config.mode,
                prompt_messages=prompt_messages,
                tools=prompt_messages_tools,
                agent_scratchpad=agent_scratchpad,
                agent_prompt_message=app_orchestration_config.agent.prompt,
                instruction=instruction,
                input=query
            )

            # recale llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=[],
                stop=app_orchestration_config.model_config.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not chunks:
                raise ValueError("failed to invoke llm")

            usage_dict = {}
            react_chunks = self._handle_stream_react(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',
                action_str='',
                observation='',
                action=None
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            for chunk in react_chunks:
                if isinstance(chunk, dict):
                    # dict chunks are parsed JSON blobs: candidate actions
                    scratchpad.agent_response += json.dumps(chunk)
                    try:
                        # only the FIRST action per iteration is honored; a
                        # second dict chunk is forced into the except path below
                        if scratchpad.action:
                            raise Exception("")
                        scratchpad.action_str = json.dumps(chunk)
                        scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=chunk['action'],
                            action_input=chunk['action_input']
                        )
                    except:
                        # not a usable action (missing keys or action already
                        # set) — treat the JSON as plain thought text and stream it
                        scratchpad.thought += json.dumps(chunk)
                        yield LLMResultChunk(
                            model=self.model_config.model,
                            prompt_messages=prompt_messages,
                            system_fingerprint='',
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(
                                    content=json.dumps(chunk)
                                ),
                                usage=None
                            )
                        )
                else:
                    # plain text chunk: accumulate and stream through unchanged
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint='',
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=chunk
                            ),
                            usage=None
                        )
                    )

            agent_scratchpad.append(scratchpad)

            # get llm usage
            # NOTE(review): _handle_stream_react as written never populates
            # usage_dict['usage'], so the empty-usage fallback always runs —
            # confirm whether usage capture was dropped upstream
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=scratchpad.agent_response,
                                    messages_ids=[],
                                    llm_usage=usage_dict['usage'])

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        final_answer = scratchpad.action.action_input if \
                            isinstance(scratchpad.action.action_input, str) else \
                            json.dumps(scratchpad.action.action_input)
                    except json.JSONDecodeError:
                        final_answer = f'{scratchpad.action.action_input}'
                else:
                    function_call_state = True

                    # action is tool call, invoke tool
                    tool_call_name = scratchpad.action.action_name
                    tool_call_args = scratchpad.action.action_input
                    tool_instance = tool_instances.get(tool_call_name)
                    if not tool_instance:
                        # unknown tool: surface the error as the observation
                        answer = f"there is not a tool named {tool_call_name}"
                        self.save_agent_thought(agent_thought=agent_thought,
                                                tool_name='',
                                                tool_input='',
                                                thought=None,
                                                observation=answer,
                                                answer=answer,
                                                messages_ids=[])
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                    else:
                        # invoke tool
                        error_response = None
                        try:
                            tool_response = tool_instance.invoke(
                                user_id=self.user_id,
                                tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
                            )
                            # transform tool response to llm friendly response
                            tool_response = self.transform_tool_invoke_messages(tool_response)
                            # extract binary data from tool invoke message
                            binary_files = self.extract_tool_response_binary(tool_response)
                            # create message file
                            message_files = self.create_message_files(binary_files)
                            # publish files
                            for message_file, save_as in message_files:
                                if save_as:
                                    # expose the file to later tool calls via the variables pool
                                    self.variables_pool.set_file(tool_name=tool_call_name,
                                                                 value=message_file.id,
                                                                 name=save_as)
                                self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
                            message_file_ids = [message_file.id for message_file, _ in message_files]
                        except ToolProviderCredentialValidationError as e:
                            error_response = "Please check your tool provider credentials"
                        except (
                            ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
                        ) as e:
                            error_response = f"there is not a tool named {tool_call_name}"
                        except (
                            ToolParameterValidationError
                        ) as e:
                            error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                        except ToolInvokeError as e:
                            error_response = f"tool invoke error: {e}"
                        except Exception as e:
                            error_response = f"unknown error: {e}"

                        # tool errors become observations so the model can recover
                        if error_response:
                            observation = error_response
                        else:
                            observation = self._convert_tool_response_to_str(tool_response)

                        # save scratchpad
                        scratchpad.observation = observation

                        # save agent thought
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name=tool_call_name,
                            tool_input=tool_call_args,
                            thought=None,
                            observation=observation,
                            answer=scratchpad.agent_response,
                            messages_ids=message_file_ids,
                        )
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                    # update prompt tool message
                    for prompt_tool in prompt_messages_tools:
                        self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # stream the final answer as a closing chunk
        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=llm_usage['usage']
            ),
            system_fingerprint=''
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name='',
            tool_input='',
            thought=final_answer,
            observation='',
            answer=final_answer,
            messages_ids=[]
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)
    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
        -> Generator[Union[str, dict], None, None]:
        """
        Re-chunk a raw LLM text stream into ReAct units.

        Scans the stream character by character and yields either plain text
        (str) or a parsed JSON action blob (dict). JSON is recognized in two
        forms: inside a ``` fenced code block, or as a bare top-level {...}
        object tracked by brace counting.

        :param llm_response: raw chunk stream from the model
        :param usage: dict intended to receive usage info
            NOTE(review): nothing in this body ever writes to `usage` — the
            caller's usage fallback always triggers; confirm intent.
        """
        def parse_json(json_str):
            # best-effort parse: on failure the raw text is passed through as str
            try:
                return json.loads(json_str.strip())
            except:
                return json_str

        # name kept as-is ("extra" is a typo for "extract")
        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                # strip a leading language tag line such as "json"
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_json(json_text)

        # state for ``` fenced code blocks
        code_block_cache = ''              # text buffered while inside (or opening) a fence
        code_block_delimiter_count = 0     # consecutive backticks seen so far
        in_code_block = False
        # state for bare top-level JSON objects
        json_cache = ''                    # buffered candidate JSON text
        json_quote_count = 0               # open-brace depth
        in_json = False
        got_json = False                   # a balanced object just closed

        for response in llm_response:
            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index+steps]

                if delta == '`':
                    # buffer backticks until we know whether they form a fence
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        # fewer than 3 backticks followed by other text: flush
                        # the buffered backticks as ordinary output
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                            code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if code_block_delimiter_count == 3:
                    # a full ``` fence: closing one emits its parsed contents
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                in_json = False
                                got_json = True
                                # the closing brace must not also be emitted as
                                # plain text below, hence the early continue
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_json(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    # plain text: emit, stripping stray backticks
                    yield delta.replace('`', '')

                index += steps

        # flush whatever is still buffered when the stream ends
        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_json(json_cache)
  392. def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
  393. """
  394. fill in inputs from external data tools
  395. """
  396. for key, value in inputs.items():
  397. try:
  398. instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
  399. except Exception as e:
  400. continue
  401. return instruction
  402. def _init_agent_scratchpad(self,
  403. agent_scratchpad: list[AgentScratchpadUnit],
  404. messages: list[PromptMessage]
  405. ) -> list[AgentScratchpadUnit]:
  406. """
  407. init agent scratchpad
  408. """
  409. current_scratchpad: AgentScratchpadUnit = None
  410. for message in messages:
  411. if isinstance(message, AssistantPromptMessage):
  412. current_scratchpad = AgentScratchpadUnit(
  413. agent_response=message.content,
  414. thought=message.content,
  415. action_str='',
  416. action=None,
  417. observation=None
  418. )
  419. if message.tool_calls:
  420. try:
  421. current_scratchpad.action = AgentScratchpadUnit.Action(
  422. action_name=message.tool_calls[0].function.name,
  423. action_input=json.loads(message.tool_calls[0].function.arguments)
  424. )
  425. except:
  426. pass
  427. agent_scratchpad.append(current_scratchpad)
  428. elif isinstance(message, ToolPromptMessage):
  429. if current_scratchpad:
  430. current_scratchpad.observation = message.content
  431. return agent_scratchpad
  432. def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  433. agent_prompt_message: AgentPromptEntity,
  434. ):
  435. """
  436. check chain of thought prompt messages, a standard prompt message is like:
  437. Respond to the human as helpfully and accurately as possible.
  438. {{instruction}}
  439. You have access to the following tools:
  440. {{tools}}
  441. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  442. Valid action values: "Final Answer" or {{tool_names}}
  443. Provide only ONE action per $JSON_BLOB, as shown:
  444. ```
  445. {
  446. "action": $TOOL_NAME,
  447. "action_input": $ACTION_INPUT
  448. }
  449. ```
  450. """
  451. # parse agent prompt message
  452. first_prompt = agent_prompt_message.first_prompt
  453. next_iteration = agent_prompt_message.next_iteration
  454. if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
  455. raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
  456. # check instruction, tools, and tool_names slots
  457. if not first_prompt.find("{{instruction}}") >= 0:
  458. raise ValueError("{{instruction}} is required in first_prompt")
  459. if not first_prompt.find("{{tools}}") >= 0:
  460. raise ValueError("{{tools}} is required in first_prompt")
  461. if not first_prompt.find("{{tool_names}}") >= 0:
  462. raise ValueError("{{tool_names}} is required in first_prompt")
  463. if mode == "completion":
  464. if not first_prompt.find("{{query}}") >= 0:
  465. raise ValueError("{{query}} is required in first_prompt")
  466. if not first_prompt.find("{{agent_scratchpad}}") >= 0:
  467. raise ValueError("{{agent_scratchpad}} is required in first_prompt")
  468. if mode == "completion":
  469. if not next_iteration.find("{{observation}}") >= 0:
  470. raise ValueError("{{observation}} is required in next_iteration")
  471. def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
  472. """
  473. convert agent scratchpad list to str
  474. """
  475. next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
  476. result = ''
  477. for scratchpad in agent_scratchpad:
  478. result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
  479. return result
    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      agent_scratchpad: list[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> list[PromptMessage]:
        """
        organize chain of thought prompt messages, a standard prompt message is like:
            Respond to the human as helpfully and accurately as possible.

            {{instruction}}

            You have access to the following tools:

            {{tools}}

            Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
            Valid action values: "Final Answer" or {{tool_names}}

            Provide only ONE action per $JSON_BLOB, as shown:

            ```
            {{{{
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
            }}}}
            ```

        In "chat" mode the CoT template becomes the system message and the
        latest scratchpad thought/observation are appended as an assistant/user
        pair; in "completion" mode everything is folded into a single user
        message. Raises ValueError for any other mode.
        """
        # validate the templates before using them
        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tools name
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
            .replace("{{tools}}", tools_str) \
            .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            # NOTE(review): .copy() is shallow, so assigning .content below
            # mutates the caller's SystemPromptMessage object — confirm intended
            overridden = False
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            # add assistant message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '')
                ))

            # add user message
            if len(agent_scratchpad) > 0:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or ''),
                ))

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)
                .replace("{{tools}}", tools_str)
                .replace("{{tool_names}}", tool_names)
                .replace("{{query}}", input)
                .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")
  553. def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
  554. """
  555. jsonify tool prompt messages
  556. """
  557. tools = jsonable_encoder(tools)
  558. try:
  559. return json.dumps(tools, ensure_ascii=False)
  560. except json.JSONDecodeError:
  561. return json.dumps(tools)