# structured_chat.py
  1. import re
  2. from collections.abc import Sequence
  3. from typing import Any, Optional, Union, cast
  4. from langchain import BasePromptTemplate, PromptTemplate
  5. from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
  6. from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
  7. from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
  8. from langchain.callbacks.base import BaseCallbackManager
  9. from langchain.callbacks.manager import Callbacks
  10. from langchain.memory.prompt import SUMMARY_PROMPT
  11. from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
  12. from langchain.schema import (
  13. AgentAction,
  14. AgentFinish,
  15. AIMessage,
  16. BaseMessage,
  17. HumanMessage,
  18. OutputParserException,
  19. get_buffer_string,
  20. )
  21. from langchain.tools import BaseTool
  22. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  23. from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
  24. from core.chain.llm_chain import LLMChain
  25. from core.entities.application_entities import ModelConfigEntity
  26. from core.entities.message_entities import lc_messages_to_prompt_messages
  27. FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  28. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
  29. Valid "action" values: "Final Answer" or {tool_names}
  30. Provide only ONE action per $JSON_BLOB, as shown:
  31. ```
  32. {{{{
  33. "action": $TOOL_NAME,
  34. "action_input": $INPUT
  35. }}}}
  36. ```
  37. Follow this format:
  38. Question: input question to answer
  39. Thought: consider previous and subsequent steps
  40. Action:
  41. ```
  42. $JSON_BLOB
  43. ```
  44. Observation: action result
  45. ... (repeat Thought/Action/Observation N times)
  46. Thought: I know what to respond
  47. Action:
  48. ```
  49. {{{{
  50. "action": "Final Answer",
  51. "action_input": "Final response to human"
  52. }}}}
  53. ```"""
  54. class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
  55. moving_summary_buffer: str = ""
  56. moving_summary_index: int = 0
  57. summary_model_config: ModelConfigEntity = None
  58. class Config:
  59. """Configuration for this pydantic object."""
  60. arbitrary_types_allowed = True
  61. def should_use_agent(self, query: str):
  62. """
  63. return should use agent
  64. Using the ReACT mode to determine whether an agent is needed is costly,
  65. so it's better to just use an Agent for reasoning, which is cheaper.
  66. :param query:
  67. :return:
  68. """
  69. return True
  70. def plan(
  71. self,
  72. intermediate_steps: list[tuple[AgentAction, str]],
  73. callbacks: Callbacks = None,
  74. **kwargs: Any,
  75. ) -> Union[AgentAction, AgentFinish]:
  76. """Given input, decided what to do.
  77. Args:
  78. intermediate_steps: Steps the LLM has taken to date,
  79. along with observatons
  80. callbacks: Callbacks to run.
  81. **kwargs: User inputs.
  82. Returns:
  83. Action specifying what tool to use.
  84. """
  85. full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
  86. prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
  87. messages = []
  88. if prompts:
  89. messages = prompts[0].to_messages()
  90. prompt_messages = lc_messages_to_prompt_messages(messages)
  91. rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
  92. if rest_tokens < 0:
  93. full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
  94. try:
  95. full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
  96. except Exception as e:
  97. raise e
  98. try:
  99. agent_decision = self.output_parser.parse(full_output)
  100. if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
  101. tool_inputs = agent_decision.tool_input
  102. if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
  103. tool_inputs['query'] = kwargs['input']
  104. agent_decision.tool_input = tool_inputs
  105. return agent_decision
  106. except OutputParserException:
  107. return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
  108. "I don't know how to respond to that."}, "")
  109. def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
  110. if len(intermediate_steps) >= 2 and self.summary_model_config:
  111. should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
  112. should_summary_messages = [AIMessage(content=observation)
  113. for _, observation in should_summary_intermediate_steps]
  114. if self.moving_summary_index == 0:
  115. should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
  116. self.moving_summary_index = len(intermediate_steps)
  117. else:
  118. error_msg = "Exceeded LLM tokens limit, stopped."
  119. raise ExceededLLMTokensLimitError(error_msg)
  120. if self.moving_summary_buffer and 'chat_history' in kwargs:
  121. kwargs["chat_history"].pop()
  122. self.moving_summary_buffer = self.predict_new_summary(
  123. messages=should_summary_messages,
  124. existing_summary=self.moving_summary_buffer
  125. )
  126. if 'chat_history' in kwargs:
  127. kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
  128. return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
  129. def predict_new_summary(
  130. self, messages: list[BaseMessage], existing_summary: str
  131. ) -> str:
  132. new_lines = get_buffer_string(
  133. messages,
  134. human_prefix="Human",
  135. ai_prefix="AI",
  136. )
  137. chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
  138. return chain.predict(summary=existing_summary, new_lines=new_lines)
  139. @classmethod
  140. def create_prompt(
  141. cls,
  142. tools: Sequence[BaseTool],
  143. prefix: str = PREFIX,
  144. suffix: str = SUFFIX,
  145. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  146. format_instructions: str = FORMAT_INSTRUCTIONS,
  147. input_variables: Optional[list[str]] = None,
  148. memory_prompts: Optional[list[BasePromptTemplate]] = None,
  149. ) -> BasePromptTemplate:
  150. tool_strings = []
  151. for tool in tools:
  152. args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
  153. tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
  154. formatted_tools = "\n".join(tool_strings)
  155. tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
  156. format_instructions = format_instructions.format(tool_names=tool_names)
  157. template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
  158. if input_variables is None:
  159. input_variables = ["input", "agent_scratchpad"]
  160. _memory_prompts = memory_prompts or []
  161. messages = [
  162. SystemMessagePromptTemplate.from_template(template),
  163. *_memory_prompts,
  164. HumanMessagePromptTemplate.from_template(human_message_template),
  165. ]
  166. return ChatPromptTemplate(input_variables=input_variables, messages=messages)
  167. @classmethod
  168. def create_completion_prompt(
  169. cls,
  170. tools: Sequence[BaseTool],
  171. prefix: str = PREFIX,
  172. format_instructions: str = FORMAT_INSTRUCTIONS,
  173. input_variables: Optional[list[str]] = None,
  174. ) -> PromptTemplate:
  175. """Create prompt in the style of the zero shot agent.
  176. Args:
  177. tools: List of tools the agent will have access to, used to format the
  178. prompt.
  179. prefix: String to put before the list of tools.
  180. input_variables: List of input variables the final prompt will expect.
  181. Returns:
  182. A PromptTemplate with the template assembled from the pieces here.
  183. """
  184. suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
  185. Question: {input}
  186. Thought: {agent_scratchpad}
  187. """
  188. tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
  189. tool_names = ", ".join([tool.name for tool in tools])
  190. format_instructions = format_instructions.format(tool_names=tool_names)
  191. template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
  192. if input_variables is None:
  193. input_variables = ["input", "agent_scratchpad"]
  194. return PromptTemplate(template=template, input_variables=input_variables)
  195. def _construct_scratchpad(
  196. self, intermediate_steps: list[tuple[AgentAction, str]]
  197. ) -> str:
  198. agent_scratchpad = ""
  199. for action, observation in intermediate_steps:
  200. agent_scratchpad += action.log
  201. agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
  202. if not isinstance(agent_scratchpad, str):
  203. raise ValueError("agent_scratchpad should be of type string.")
  204. if agent_scratchpad:
  205. llm_chain = cast(LLMChain, self.llm_chain)
  206. if llm_chain.model_config.mode == "chat":
  207. return (
  208. f"This was your previous work "
  209. f"(but I haven't seen any of it! I only see what "
  210. f"you return as final answer):\n{agent_scratchpad}"
  211. )
  212. else:
  213. return agent_scratchpad
  214. else:
  215. return agent_scratchpad
  216. @classmethod
  217. def from_llm_and_tools(
  218. cls,
  219. model_config: ModelConfigEntity,
  220. tools: Sequence[BaseTool],
  221. callback_manager: Optional[BaseCallbackManager] = None,
  222. output_parser: Optional[AgentOutputParser] = None,
  223. prefix: str = PREFIX,
  224. suffix: str = SUFFIX,
  225. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  226. format_instructions: str = FORMAT_INSTRUCTIONS,
  227. input_variables: Optional[list[str]] = None,
  228. memory_prompts: Optional[list[BasePromptTemplate]] = None,
  229. agent_llm_callback: Optional[AgentLLMCallback] = None,
  230. **kwargs: Any,
  231. ) -> Agent:
  232. """Construct an agent from an LLM and tools."""
  233. cls._validate_tools(tools)
  234. if model_config.mode == "chat":
  235. prompt = cls.create_prompt(
  236. tools,
  237. prefix=prefix,
  238. suffix=suffix,
  239. human_message_template=human_message_template,
  240. format_instructions=format_instructions,
  241. input_variables=input_variables,
  242. memory_prompts=memory_prompts,
  243. )
  244. else:
  245. prompt = cls.create_completion_prompt(
  246. tools,
  247. prefix=prefix,
  248. format_instructions=format_instructions,
  249. input_variables=input_variables,
  250. )
  251. llm_chain = LLMChain(
  252. model_config=model_config,
  253. prompt=prompt,
  254. callback_manager=callback_manager,
  255. agent_llm_callback=agent_llm_callback,
  256. parameters={
  257. 'temperature': 0.2,
  258. 'top_p': 0.3,
  259. 'max_tokens': 1500
  260. }
  261. )
  262. tool_names = [tool.name for tool in tools]
  263. _output_parser = output_parser
  264. return cls(
  265. llm_chain=llm_chain,
  266. allowed_tools=tool_names,
  267. output_parser=_output_parser,
  268. **kwargs,
  269. )