agent_executor.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import enum
  2. import logging
  3. from typing import Union, Optional
  4. from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
  5. from langchain.base_language import BaseLanguageModel
  6. from langchain.callbacks.manager import Callbacks
  7. from langchain.memory.chat_memory import BaseChatMemory
  8. from langchain.tools import BaseTool
  9. from pydantic import BaseModel, Extra
  10. from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
  11. from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
  12. from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
  13. from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
  14. from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
  15. from langchain.agents import AgentExecutor as LCAgentExecutor
  16. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  17. class PlanningStrategy(str, enum.Enum):
  18. ROUTER = 'router'
  19. REACT = 'react'
  20. FUNCTION_CALL = 'function_call'
  21. MULTI_FUNCTION_CALL = 'multi_function_call'
  22. class AgentConfiguration(BaseModel):
  23. strategy: PlanningStrategy
  24. llm: BaseLanguageModel
  25. tools: list[BaseTool]
  26. summary_llm: BaseLanguageModel
  27. dataset_llm: BaseLanguageModel
  28. memory: Optional[BaseChatMemory] = None
  29. callbacks: Callbacks = None
  30. max_iterations: int = 6
  31. max_execution_time: Optional[float] = None
  32. early_stopping_method: str = "generate"
  33. # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
  34. class Config:
  35. """Configuration for this pydantic object."""
  36. extra = Extra.forbid
  37. arbitrary_types_allowed = True
  38. class AgentExecuteResult(BaseModel):
  39. strategy: PlanningStrategy
  40. output: Optional[str]
  41. configuration: AgentConfiguration
  42. class AgentExecutor:
  43. def __init__(self, configuration: AgentConfiguration):
  44. self.configuration = configuration
  45. self.agent = self._init_agent()
  46. def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
  47. if self.configuration.strategy == PlanningStrategy.REACT:
  48. agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
  49. llm=self.configuration.llm,
  50. tools=self.configuration.tools,
  51. output_parser=StructuredChatOutputParser(),
  52. summary_llm=self.configuration.summary_llm,
  53. verbose=True
  54. )
  55. elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
  56. agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
  57. llm=self.configuration.llm,
  58. tools=self.configuration.tools,
  59. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
  60. summary_llm=self.configuration.summary_llm,
  61. verbose=True
  62. )
  63. elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
  64. agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
  65. llm=self.configuration.llm,
  66. tools=self.configuration.tools,
  67. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
  68. summary_llm=self.configuration.summary_llm,
  69. verbose=True
  70. )
  71. elif self.configuration.strategy == PlanningStrategy.ROUTER:
  72. self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
  73. agent = MultiDatasetRouterAgent.from_llm_and_tools(
  74. llm=self.configuration.dataset_llm,
  75. tools=self.configuration.tools,
  76. extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
  77. verbose=True
  78. )
  79. else:
  80. raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
  81. return agent
  82. def should_use_agent(self, query: str) -> bool:
  83. return self.agent.should_use_agent(query)
  84. def run(self, query: str) -> AgentExecuteResult:
  85. agent_executor = LCAgentExecutor.from_agent_and_tools(
  86. agent=self.agent,
  87. tools=self.configuration.tools,
  88. memory=self.configuration.memory,
  89. max_iterations=self.configuration.max_iterations,
  90. max_execution_time=self.configuration.max_execution_time,
  91. early_stopping_method=self.configuration.early_stopping_method,
  92. callbacks=self.configuration.callbacks
  93. )
  94. try:
  95. output = agent_executor.run(query)
  96. except Exception:
  97. logging.exception("agent_executor run failed")
  98. output = None
  99. return AgentExecuteResult(
  100. output=output,
  101. strategy=self.configuration.strategy,
  102. configuration=self.configuration
  103. )