tool_chain.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from typing import List, Dict, Optional, Any
  2. from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
  3. from langchain.chains.base import Chain
  4. from langchain.tools import BaseTool
  class ToolChain(Chain):
      """Chain that forwards a single input value to a wrapped tool.

      The value found under ``input_key`` in the chain inputs is handed to
      ``tool``; the tool's result is returned under ``output_key``.
      """

      input_key: str = "input"  #: :meta private:
      output_key: str = "output"  #: :meta private:
      # The tool invoked on every run of this chain.
      tool: BaseTool

      @property
      def _chain_type(self) -> str:
          """Return the identifier ``"tool_chain"`` for this chain type."""
          return "tool_chain"

      @property
      def input_keys(self) -> List[str]:
          """Expect input key.

          :meta private:
          """
          return [self.input_key]

      @property
      def output_keys(self) -> List[str]:
          """Return output key.

          :meta private:
          """
          return [self.output_key]

      def _call(
          self,
          inputs: Dict[str, Any],
          run_manager: Optional[CallbackManagerForChainRun] = None,
      ) -> Dict[str, Any]:
          """Run the tool synchronously on ``inputs[self.input_key]``.

          Args:
              inputs: Chain inputs; must contain ``self.input_key``.
              run_manager: Callback manager for this chain run. When given,
                  a child manager is handed to the tool so the tool's
                  callbacks are attached to this chain run.

          Returns:
              A dict mapping ``self.output_key`` to the tool's output.

          Raises:
              KeyError: If ``self.input_key`` is missing from ``inputs``.
          """
          tool_input = inputs[self.input_key]
          # Propagate this run's callbacks so tracing/handlers registered on
          # the chain also observe the nested tool invocation.
          callbacks = run_manager.get_child() if run_manager is not None else None
          output = self.tool.run(
              tool_input, verbose=self.verbose, callbacks=callbacks
          )
          return {self.output_key: output}

      async def _acall(
          self,
          inputs: Dict[str, Any],
          run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
      ) -> Dict[str, Any]:
          """Run the tool asynchronously on ``inputs[self.input_key]``.

          Args:
              inputs: Chain inputs; must contain ``self.input_key``.
              run_manager: Async callback manager for this chain run. When
                  given, a child manager is handed to the tool so the tool's
                  callbacks are attached to this chain run.

          Returns:
              A dict mapping ``self.output_key`` to the tool's output.

          Raises:
              KeyError: If ``self.input_key`` is missing from ``inputs``.
          """
          tool_input = inputs[self.input_key]
          # Same callback propagation as the synchronous path.
          callbacks = run_manager.get_child() if run_manager is not None else None
          output = await self.tool.arun(
              tool_input, verbose=self.verbose, callbacks=callbacks
          )
          return {self.output_key: output}