main_chain_builder.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from typing import Optional, List
  2. from langchain.callbacks import SharedCallbackManager
  3. from langchain.chains import SequentialChain
  4. from langchain.chains.base import Chain
  5. from langchain.memory.chat_memory import BaseChatMemory
  6. from core.agent.agent_builder import AgentBuilder
  7. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  8. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  9. from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
  10. from core.chain.chain_builder import ChainBuilder
  11. from core.constant import llm_constant
  12. from core.conversation_message_task import ConversationMessageTask
  13. from core.tool.dataset_tool_builder import DatasetToolBuilder
  14. class MainChainBuilder:
  15. @classmethod
  16. def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
  17. conversation_message_task: ConversationMessageTask):
  18. first_input_key = "input"
  19. final_output_key = "output"
  20. chains = []
  21. chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
  22. # agent mode
  23. tool_chains, chains_output_key = cls.get_agent_chains(
  24. tenant_id=tenant_id,
  25. agent_mode=agent_mode,
  26. memory=memory,
  27. dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
  28. agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
  29. )
  30. chains += tool_chains
  31. if chains_output_key:
  32. final_output_key = chains_output_key
  33. if len(chains) == 0:
  34. return None
  35. for chain in chains:
  36. # do not add handler into singleton callback manager
  37. if not isinstance(chain.callback_manager, SharedCallbackManager):
  38. chain.callback_manager.add_handler(chain_callback_handler)
  39. # build main chain
  40. overall_chain = SequentialChain(
  41. chains=chains,
  42. input_variables=[first_input_key],
  43. output_variables=[final_output_key],
  44. memory=memory, # only for use the memory prompt input key
  45. )
  46. return overall_chain
  47. @classmethod
  48. def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
  49. dataset_tool_callback_handler: DatasetToolCallbackHandler,
  50. agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
  51. # agent mode
  52. chains = []
  53. if agent_mode and agent_mode.get('enabled'):
  54. tools = agent_mode.get('tools', [])
  55. pre_fixed_chains = []
  56. agent_tools = []
  57. for tool in tools:
  58. tool_type = list(tool.keys())[0]
  59. tool_config = list(tool.values())[0]
  60. if tool_type == 'sensitive-word-avoidance':
  61. chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
  62. if chain:
  63. pre_fixed_chains.append(chain)
  64. elif tool_type == "dataset":
  65. dataset_tool = DatasetToolBuilder.build_dataset_tool(
  66. tenant_id=tenant_id,
  67. dataset_id=tool_config.get("id"),
  68. response_mode='no_synthesizer', # "compact"
  69. callback_handler=dataset_tool_callback_handler
  70. )
  71. if dataset_tool:
  72. agent_tools.append(dataset_tool)
  73. # add pre-fixed chains
  74. chains += pre_fixed_chains
  75. if len(agent_tools) == 1:
  76. # tool to chain
  77. tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output')
  78. chains.append(tool_chain)
  79. elif len(agent_tools) > 1:
  80. # build agent config
  81. agent_chain = AgentBuilder.to_agent_chain(
  82. tenant_id=tenant_id,
  83. tools=agent_tools,
  84. memory=memory,
  85. dataset_tool_callback_handler=dataset_tool_callback_handler,
  86. agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
  87. )
  88. chains.append(agent_chain)
  89. final_output_key = cls.get_chains_output_key(chains)
  90. return chains, final_output_key
  91. @classmethod
  92. def get_chains_output_key(cls, chains: List[Chain]):
  93. if len(chains) > 0:
  94. return chains[-1].output_keys[0]
  95. return None