chain_builder.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334
  1. from typing import Optional
  2. from langchain.callbacks import CallbackManager
  3. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  4. from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
  5. from core.chain.tool_chain import ToolChain
  6. class ChainBuilder:
  7. @classmethod
  8. def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
  9. return ToolChain(
  10. tool=tool,
  11. input_key=kwargs.get('input_key', 'input'),
  12. output_key=kwargs.get('output_key', 'tool_output'),
  13. callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
  14. )
  15. @classmethod
  16. def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[
  17. SensitiveWordAvoidanceChain]:
  18. sensitive_words = tool_config.get("words", "")
  19. if tool_config.get("enabled", False) \
  20. and sensitive_words:
  21. return SensitiveWordAvoidanceChain(
  22. sensitive_words=sensitive_words.split(","),
  23. canned_response=tool_config.get("canned_response", ''),
  24. output_key="sensitive_word_avoidance_output",
  25. callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
  26. **kwargs
  27. )
  28. return None