chain_builder.py 1.2 KB

1234567891011121314151617181920212223242526272829303132
  1. from typing import Optional
  2. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  3. from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
  4. from core.chain.tool_chain import ToolChain
  class ChainBuilder:
      """Factory helpers that build preconfigured chain instances."""

      @classmethod
      def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
          """Wrap ``tool`` in a ``ToolChain``.

          Only the two keyword arguments below are read; any other keyword
          arguments are accepted and ignored.

          Args:
              tool: The tool object handed to ``ToolChain``.
              input_key: Chain input key (default ``'input'``).
              output_key: Chain output key (default ``'tool_output'``).

          Returns:
              A ``ToolChain`` whose only callback is a freshly created
              ``DifyStdOutCallbackHandler``.
          """
          return ToolChain(
              tool=tool,
              input_key=kwargs.get('input_key', 'input'),
              output_key=kwargs.get('output_key', 'tool_output'),
              callbacks=[DifyStdOutCallbackHandler()],
          )

      @classmethod
      def to_sensitive_word_avoidance_chain(
          cls, tool_config: dict, **kwargs
      ) -> Optional[SensitiveWordAvoidanceChain]:
          """Build a ``SensitiveWordAvoidanceChain`` from ``tool_config``.

          Args:
              tool_config: Mapping read for three keys:
                  ``"enabled"`` (truthy to build a chain, default ``False``),
                  ``"words"`` (comma-separated word list, default ``""``) and
                  ``"canned_response"`` (default ``''``).
              **kwargs: Forwarded to ``SensitiveWordAvoidanceChain``. They may
                  override the defaults for ``output_key``
                  (``"sensitive_word_avoidance_output"``) and ``callbacks``
                  (a fresh ``DifyStdOutCallbackHandler``).

          Returns:
              The configured chain, or ``None`` when the feature is disabled
              or ``"words"`` contains no non-blank entries.
          """
          if not tool_config.get("enabled", False):
              return None

          sensitive_words = cls._split_sensitive_words(tool_config.get("words", ""))
          if not sensitive_words:
              return None

          # Defaults first, caller kwargs last: an explicit output_key/callbacks
          # now overrides the default instead of raising
          # "got multiple values for keyword argument".
          chain_kwargs = {
              "output_key": "sensitive_word_avoidance_output",
              "callbacks": [DifyStdOutCallbackHandler()],
              **kwargs,
          }
          return SensitiveWordAvoidanceChain(
              sensitive_words=sensitive_words,
              canned_response=tool_config.get("canned_response", ''),
              **chain_kwargs,
          )

      @staticmethod
      def _split_sensitive_words(words) -> list:
          """Split a comma-separated word string into trimmed, non-empty words.

          ``None`` or ``""`` yields ``[]``. Surrounding whitespace is stripped
          (``"foo, bar"`` -> ``["foo", "bar"]``). Empty entries produced by
          stray or trailing commas are dropped; an empty word is never
          meaningful and, presumably, would match any text if the chain does
          substring matching.
          """
          if not words:
              return []
          return [word.strip() for word in words.split(",") if word.strip()]