123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- import enum
- import logging
- from typing import Union, Optional
- from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
- from langchain.base_language import BaseLanguageModel
- from langchain.callbacks.manager import Callbacks
- from langchain.memory.chat_memory import BaseChatMemory
- from langchain.tools import BaseTool
- from pydantic import BaseModel, Extra
- from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
- from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
- from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
- from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
- from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
- from langchain.agents import AgentExecutor as LCAgentExecutor
- from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
    """Planning strategies that select which agent implementation is built.

    Because the enum also subclasses ``str``, every member compares equal to
    its raw string value (e.g. ``PlanningStrategy.REACT == "react"``).
    """

    # Multi-dataset router agent over DatasetRetrieverTool instances only.
    ROUTER = "router"
    # Structured-chat agent with StructuredChatOutputParser.
    REACT = "react"
    # OpenAI function-call agent (AutoSummarizingOpenAIFunctionCallAgent).
    FUNCTION_CALL = "function_call"
    # Multi-function-call agent (AutoSummarizingOpenMultiAIFunctionCallAgent).
    MULTI_FUNCTION_CALL = "multi_function_call"
class AgentConfiguration(BaseModel):
    """Everything needed to build and run one agent.

    ``strategy`` picks the agent type. ``llm`` / ``summary_llm`` feed the
    REACT and function-call agents, while ``dataset_llm`` feeds the ROUTER
    agent. The remaining fields are forwarded to LangChain's AgentExecutor
    (memory, callbacks, iteration/time limits, early-stopping method).
    """

    strategy: PlanningStrategy
    llm: BaseLanguageModel
    tools: list[BaseTool]
    summary_llm: BaseLanguageModel
    dataset_llm: BaseLanguageModel
    memory: Union[BaseChatMemory, None] = None
    callbacks: Callbacks = None
    max_iterations: int = 6
    max_execution_time: Union[float, None] = None
    # `generate` will continue to complete the last inference after reaching
    # the iteration limit or request time limit.
    early_stopping_method: str = "generate"

    class Config:
        """Configuration for this pydantic object."""

        # Unknown constructor keywords are rejected rather than ignored.
        extra = Extra.forbid
        # Needed because LLM, tool and memory fields are not pydantic types.
        arbitrary_types_allowed = True
class AgentExecuteResult(BaseModel):
    """Result of one agent execution, paired with the settings that produced it."""

    strategy: PlanningStrategy
    # Final agent answer; AgentExecutor.run stores None here when the run raised.
    output: Union[str, None]
    configuration: AgentConfiguration
class AgentExecutor:
    """Build an agent for the configured planning strategy and run queries with it.

    The agent is constructed once in ``__init__``. Each :meth:`run` call wraps
    it in a fresh LangChain ``AgentExecutor``.
    """

    def __init__(self, configuration: AgentConfiguration):
        """Store *configuration* and eagerly construct the agent.

        Raises:
            NotImplementedError: if ``configuration.strategy`` is not a
                supported planning strategy.
        """
        self.configuration = configuration
        self.agent = self._init_agent()

    def _chat_history_messages(self):
        """Return the memory buffer used as ``extra_prompt_messages``, or None without memory."""
        memory = self.configuration.memory
        # Evaluated only by the strategies that need it, so memory types without
        # a ``buffer`` attribute still work with the REACT strategy.
        return memory.buffer if memory else None

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        """Instantiate the agent class matching ``configuration.strategy``.

        Side effect: for the ROUTER strategy, ``self.configuration.tools`` is
        replaced in place by only its ``DatasetRetrieverTool`` entries. The
        configuration object is shared with the caller, and :meth:`run` later
        reads the same filtered list.

        Raises:
            NotImplementedError: for an unhandled strategy.
        """
        config = self.configuration
        strategy = config.strategy

        if strategy == PlanningStrategy.REACT:
            return AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                llm=config.llm,
                tools=config.tools,
                output_parser=StructuredChatOutputParser(),
                summary_llm=config.summary_llm,
                verbose=True,
            )

        if strategy == PlanningStrategy.FUNCTION_CALL:
            return AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                llm=config.llm,
                tools=config.tools,
                extra_prompt_messages=self._chat_history_messages(),  # chat history from memory
                summary_llm=config.summary_llm,
                verbose=True,
            )

        if strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
            return AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
                llm=config.llm,
                tools=config.tools,
                extra_prompt_messages=self._chat_history_messages(),  # chat history from memory
                summary_llm=config.summary_llm,
                verbose=True,
            )

        if strategy == PlanningStrategy.ROUTER:
            # The router agent only works over dataset retrievers; drop any other tools.
            config.tools = [t for t in config.tools if isinstance(t, DatasetRetrieverTool)]
            return MultiDatasetRouterAgent.from_llm_and_tools(
                llm=config.dataset_llm,
                tools=config.tools,
                extra_prompt_messages=self._chat_history_messages(),
                verbose=True,
            )

        raise NotImplementedError(f"Unknown Agent Strategy: {config.strategy}")

    def should_use_agent(self, query: str) -> bool:
        """Delegate to the underlying agent's ``should_use_agent`` for *query*."""
        return self.agent.should_use_agent(query)

    def run(self, query: str) -> AgentExecuteResult:
        """Run *query* through the agent and wrap the answer in an AgentExecuteResult.

        Any exception raised while the LangChain executor runs is logged with
        its traceback and swallowed, and ``output`` is None in that case. This
        best-effort behaviour is deliberate. Errors raised while constructing
        the executor itself are not caught.
        """
        config = self.configuration
        agent_executor = LCAgentExecutor.from_agent_and_tools(
            agent=self.agent,
            tools=config.tools,
            memory=config.memory,
            max_iterations=config.max_iterations,
            max_execution_time=config.max_execution_time,
            early_stopping_method=config.early_stopping_method,
            callbacks=config.callbacks,
        )

        try:
            output = agent_executor.run(query)
        except Exception:
            logging.exception("agent_executor run failed")
            output = None

        return AgentExecuteResult(
            output=output,
            strategy=config.strategy,
            configuration=config,
        )
|