structured_chat.py

import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The keywords "Thought", "Action", "Action Input" and "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_instance: Optional[BaseLLM] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str) -> bool:
        """
        Return whether the agent should be used for this query.

        Deciding via a ReACT round-trip whether an agent is needed is itself
        costly, so it is cheaper to always reason with the agent directly.

        :param query: user query
        :return: always True
        """
        return True

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations.
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        # If the rendered prompt no longer fits into the model's context
        # window, fold older steps into a summary before predicting.
        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise self.llm_chain.model_instance.handle_exceptions(e)

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                # Pass the user's original input through to dataset tools
                # instead of the model's paraphrase of it.
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, but the model's answer is invalid "
                                          "and I don't know how to respond to it."}, "")

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        """Fold older tool observations into a rolling summary so the prompt
        fits the context window again, keeping only the latest step verbatim."""
        if len(intermediate_steps) >= 2 and self.summary_model_instance:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
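
    # Worked example of the moving window (illustrative, not executed): with
    # five intermediate steps and moving_summary_index == 0, the observations
    # of steps 0..3 (prefixed by the original user input) are folded into
    # moving_summary_buffer, only step 4 is kept verbatim, and
    # moving_summary_index advances to 5 so a later overflow summarizes only
    # the steps accumulated since.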

    def predict_new_summary(
        self, messages: List[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            # Escape braces in the tool's args schema so they survive the two
            # formatting passes described above FORMAT_INSTRUCTIONS.
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
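
    # For illustration (hypothetical single tool named "search"), the system
    # message template assembled above is:
    #   PREFIX + "\n\n" + 'search: <description>, args: <brace-escaped schema>'
    #   + "\n\n" + format_instructions (with tool_names = '"search"')
    #   + "\n\n" + SUFFIX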

    @classmethod
    def create_completion_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero-shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            format_instructions: Instructions describing the expected output format.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")

        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
                # In chat mode the scratchpad is sent as part of a human
                # message, so remind the model that these intermediate steps
                # are not visible to the user.
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
        cls,
        model_instance: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_instance.model_mode == ModelMode.CHAT:
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_instance=model_instance,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
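

# Minimal usage sketch (illustrative assumptions, not part of this module):
# `my_model_instance` and `my_tools` are hypothetical stand-ins for a
# configured BaseLLM and a list of BaseTool instances provided by the host
# application; the parser comes from
# langchain.agents.structured_chat.output_parser.
#
#     agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
#         model_instance=my_model_instance,           # hypothetical BaseLLM
#         tools=my_tools,                             # hypothetical tools
#         output_parser=StructuredChatOutputParser(),
#         summary_model_instance=my_model_instance,   # enables auto-summarizing
#     )
#     decision = agent.plan(intermediate_steps=[], input="What's the weather?")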