import re
from typing import List, Tuple, Any, Union, Sequence, Optional

from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.tools import BaseTool

from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
# ReAct-style format instructions injected into the agent's system prompt.
# Note the quadrupled braces: this string is first passed through
# str.format(tool_names=...) in create_prompt (turning {{{{ into {{), and the
# result is then rendered as a prompt template (turning {{ into {), so the
# model ultimately sees single braces in the JSON examples.
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    """Structured-chat agent that summarizes older intermediate steps when the
    assembled prompt would overflow the model's context window.

    The rolling summary state lives on the instance: ``moving_summary_buffer``
    holds the current summary text and ``moving_summary_index`` marks how far
    into ``intermediate_steps`` has already been folded into that summary.
    """

    # Rolling summary of already-summarized steps (empty until the first
    # token overflow triggers summarization).
    moving_summary_buffer: str = ""
    # Offset into intermediate_steps of the first not-yet-summarized step.
    moving_summary_index: int = 0
    # LLM used to produce summaries; when None, a token overflow raises
    # ExceededLLMTokensLimitError instead of summarizing (see summarize_messages).
    summary_llm: Optional[BaseLanguageModel] = None
    # Project wrapper around the main model; used here for token counting
    # (get_message_rest_tokens) and provider-error mapping (handle_exceptions).
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str) -> bool:
        """
        return should use agent

        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query: user query (unused; this implementation always opts in)
        :return: always True
        """
        return True

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)

        # Render the prompt once purely to measure its token footprint.
        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        # Negative remainder means the prompt no longer fits the context
        # window; compress older steps into a summary before predicting.
        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            # Map provider-specific failures onto the project's exception types.
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                # For dataset lookups, force the query to be the user's
                # original input rather than the model-generated action_input.
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            # Unparseable model output: finish gracefully instead of crashing.
            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
        """Fold older intermediate steps into the rolling summary and rebuild inputs.

        Keeps only the most recent step verbatim in the scratchpad; everything
        between the last summary point and that step is summarized via
        ``summary_llm``.  Mutates ``kwargs['chat_history']`` in place when present
        (replaces the previous summary message with a refreshed one).

        :param intermediate_steps: all (action, observation) steps so far
        :param kwargs: user inputs, possibly including ``input`` and ``chat_history``
        :return: new full inputs containing only the latest intermediate step
        :raises ExceededLLMTokensLimitError: if there is nothing to summarize
            (fewer than two steps) or no ``summary_llm`` is configured
        """
        if len(intermediate_steps) >= 2 and self.summary_llm:
            # Summarize steps after the previous summary point, excluding the
            # newest one (kept verbatim below).
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                # First summarization: include the original user input for context.
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        summary_handler = SummarizerMixin(llm=self.summary_llm)
        if self.moving_summary_buffer and 'chat_history' in kwargs:
            # Drop the previously appended summary message; a refreshed
            # summary is appended again below.
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = summary_handler.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        # Only the most recent step remains in the agent scratchpad.
        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt: a system message composed of prefix, tool
        descriptions, format instructions and suffix, followed by any memory
        prompts and the human message template.
        """
        tool_strings = []
        for tool in tools:
            # Braces in the tool args schema are quadrupled (not doubled as in
            # stock langchain) — presumably to survive an extra str.format
            # pass over the template; TODO(review) confirm against rendering.
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        # First of the two format passes applied to FORMAT_INSTRUCTIONS.
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct the agent from an LLM and tools.

        Thin pass-through to ``StructuredChatAgent.from_llm_and_tools``; it
        exists so this subclass's FORMAT_INSTRUCTIONS is the default and so
        extra fields (e.g. ``model_instance``, ``summary_llm``) can be passed
        via ``**kwargs``.
        """
        return super().from_llm_and_tools(
            llm=llm,
            tools=tools,
            callback_manager=callback_manager,
            output_parser=output_parser,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
            **kwargs,
        )