prompt_builder.py 996 B

123456789101112131415161718192021222324
  1. from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
  2. from core.prompt.prompt_template import PromptTemplateParser
  3. class PromptBuilder:
  4. @classmethod
  5. def parse_prompt(cls, prompt: str, inputs: dict) -> str:
  6. prompt_template = PromptTemplateParser(prompt)
  7. prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
  8. prompt = prompt_template.format(prompt_inputs)
  9. return prompt
  10. @classmethod
  11. def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  12. return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
  13. @classmethod
  14. def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  15. return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
  16. @classmethod
  17. def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  18. return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))