# structured_chat.py
import re
from typing import Any, List, Optional, Sequence, Tuple, Union

from langchain import BasePromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
# Prompt fragment telling the model to answer with a single JSON action blob.
# `{tool_names}` is filled in by `create_prompt`; the quadruple braces survive
# the two successive `.format(...)` passes the template goes through and end
# up as literal single braces in the final prompt.
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""
  43. class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
  44. moving_summary_buffer: str = ""
  45. moving_summary_index: int = 0
  46. summary_llm: BaseLanguageModel = None
  47. model_instance: BaseLLM
  48. class Config:
  49. """Configuration for this pydantic object."""
  50. arbitrary_types_allowed = True
  51. def should_use_agent(self, query: str):
  52. """
  53. return should use agent
  54. Using the ReACT mode to determine whether an agent is needed is costly,
  55. so it's better to just use an Agent for reasoning, which is cheaper.
  56. :param query:
  57. :return:
  58. """
  59. return True
  60. def plan(
  61. self,
  62. intermediate_steps: List[Tuple[AgentAction, str]],
  63. callbacks: Callbacks = None,
  64. **kwargs: Any,
  65. ) -> Union[AgentAction, AgentFinish]:
  66. """Given input, decided what to do.
  67. Args:
  68. intermediate_steps: Steps the LLM has taken to date,
  69. along with observations
  70. callbacks: Callbacks to run.
  71. **kwargs: User inputs.
  72. Returns:
  73. Action specifying what tool to use.
  74. """
  75. full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
  76. prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
  77. messages = []
  78. if prompts:
  79. messages = prompts[0].to_messages()
  80. rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
  81. if rest_tokens < 0:
  82. full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
  83. try:
  84. full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
  85. except Exception as e:
  86. new_exception = self.model_instance.handle_exceptions(e)
  87. raise new_exception
  88. try:
  89. agent_decision = self.output_parser.parse(full_output)
  90. if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
  91. tool_inputs = agent_decision.tool_input
  92. if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
  93. tool_inputs['query'] = kwargs['input']
  94. agent_decision.tool_input = tool_inputs
  95. return agent_decision
  96. except OutputParserException:
  97. return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
  98. "I don't know how to respond to that."}, "")
  99. def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
  100. if len(intermediate_steps) >= 2 and self.summary_llm:
  101. should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
  102. should_summary_messages = [AIMessage(content=observation)
  103. for _, observation in should_summary_intermediate_steps]
  104. if self.moving_summary_index == 0:
  105. should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
  106. self.moving_summary_index = len(intermediate_steps)
  107. else:
  108. error_msg = "Exceeded LLM tokens limit, stopped."
  109. raise ExceededLLMTokensLimitError(error_msg)
  110. summary_handler = SummarizerMixin(llm=self.summary_llm)
  111. if self.moving_summary_buffer and 'chat_history' in kwargs:
  112. kwargs["chat_history"].pop()
  113. self.moving_summary_buffer = summary_handler.predict_new_summary(
  114. messages=should_summary_messages,
  115. existing_summary=self.moving_summary_buffer
  116. )
  117. if 'chat_history' in kwargs:
  118. kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
  119. return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
  120. @classmethod
  121. def create_prompt(
  122. cls,
  123. tools: Sequence[BaseTool],
  124. prefix: str = PREFIX,
  125. suffix: str = SUFFIX,
  126. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  127. format_instructions: str = FORMAT_INSTRUCTIONS,
  128. input_variables: Optional[List[str]] = None,
  129. memory_prompts: Optional[List[BasePromptTemplate]] = None,
  130. ) -> BasePromptTemplate:
  131. tool_strings = []
  132. for tool in tools:
  133. args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
  134. tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
  135. formatted_tools = "\n".join(tool_strings)
  136. tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
  137. format_instructions = format_instructions.format(tool_names=tool_names)
  138. template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
  139. if input_variables is None:
  140. input_variables = ["input", "agent_scratchpad"]
  141. _memory_prompts = memory_prompts or []
  142. messages = [
  143. SystemMessagePromptTemplate.from_template(template),
  144. *_memory_prompts,
  145. HumanMessagePromptTemplate.from_template(human_message_template),
  146. ]
  147. return ChatPromptTemplate(input_variables=input_variables, messages=messages)
  148. @classmethod
  149. def from_llm_and_tools(
  150. cls,
  151. llm: BaseLanguageModel,
  152. tools: Sequence[BaseTool],
  153. callback_manager: Optional[BaseCallbackManager] = None,
  154. output_parser: Optional[AgentOutputParser] = None,
  155. prefix: str = PREFIX,
  156. suffix: str = SUFFIX,
  157. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  158. format_instructions: str = FORMAT_INSTRUCTIONS,
  159. input_variables: Optional[List[str]] = None,
  160. memory_prompts: Optional[List[BasePromptTemplate]] = None,
  161. **kwargs: Any,
  162. ) -> Agent:
  163. return super().from_llm_and_tools(
  164. llm=llm,
  165. tools=tools,
  166. callback_manager=callback_manager,
  167. output_parser=output_parser,
  168. prefix=prefix,
  169. suffix=suffix,
  170. human_message_template=human_message_template,
  171. format_instructions=format_instructions,
  172. input_variables=input_variables,
  173. memory_prompts=memory_prompts,
  174. **kwargs,
  175. )