advanced_prompt_template_service.py

import copy

from core.model_providers.models.entity.model_params import ModelMode
from core.prompt.prompt_transform import AppMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, \
    COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
    BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, \
    BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, \
    CONTEXT, BAICHUAN_CONTEXT


class AdvancedPromptTemplateService:
    """Returns the default advanced prompt template for a given app mode,
    model mode, and model family, optionally prepending a context block."""

    @classmethod
    def get_prompt(cls, args: dict) -> dict:
        app_mode = args['app_mode']
        model_mode = args['model_mode']
        model_name = args['model_name']
        has_context = args['has_context']

        # Baichuan models use their own prompt templates; everything else
        # falls back to the common templates.
        if 'baichuan' in model_name.lower():
            return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
        else:
            return cls.get_common_prompt(app_mode, model_mode, has_context)

    @classmethod
    def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        context_prompt = copy.deepcopy(CONTEXT)

        # Pick the template matching the (app mode, model mode) combination.
        if app_mode == AppMode.CHAT.value:
            if model_mode == ModelMode.COMPLETION.value:
                return cls.get_completion_prompt(
                    copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
            elif model_mode == ModelMode.CHAT.value:
                return cls.get_chat_prompt(
                    copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
        elif app_mode == AppMode.COMPLETION.value:
            if model_mode == ModelMode.COMPLETION.value:
                return cls.get_completion_prompt(
                    copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
            elif model_mode == ModelMode.CHAT.value:
                return cls.get_chat_prompt(
                    copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)

    @classmethod
    def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        # Completion-mode templates hold a single prompt text block;
        # prepend the context snippet when the caller asked for it.
        if has_context == 'true':
            prompt_template['completion_prompt_config']['prompt']['text'] = \
                context + prompt_template['completion_prompt_config']['prompt']['text']

        return prompt_template

    @classmethod
    def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
        # Chat-mode templates hold a list of prompt messages; prepend the
        # context snippet to the first message when the caller asked for it.
        if has_context == 'true':
            prompt_template['chat_prompt_config']['prompt'][0]['text'] = \
                context + prompt_template['chat_prompt_config']['prompt'][0]['text']

        return prompt_template

    @classmethod
    def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
        baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

        # Same dispatch as get_common_prompt, but with the Baichuan-specific
        # templates and context block.
        if app_mode == AppMode.CHAT.value:
            if model_mode == ModelMode.COMPLETION.value:
                return cls.get_completion_prompt(
                    copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
            elif model_mode == ModelMode.CHAT.value:
                return cls.get_chat_prompt(
                    copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
        elif app_mode == AppMode.COMPLETION.value:
            if model_mode == ModelMode.COMPLETION.value:
                return cls.get_completion_prompt(
                    copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
            elif model_mode == ModelMode.CHAT.value:
                return cls.get_chat_prompt(
                    copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
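

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original service).
# It assumes the module above is importable on the caller's path and reuses
# the AppMode / ModelMode enums imported at the top; the model name is an
# arbitrary example value.
if __name__ == '__main__':
    # A non-Baichuan model in a chat app driven through a completion-mode
    # model: get_prompt() routes to get_common_prompt(), which returns a deep
    # copy of CHAT_APP_COMPLETION_PROMPT_CONFIG with CONTEXT prepended to the
    # prompt text because has_context == 'true'.
    template = AdvancedPromptTemplateService.get_prompt({
        'app_mode': AppMode.CHAT.value,
        'model_mode': ModelMode.COMPLETION.value,
        'model_name': 'gpt-3.5-turbo',
        'has_context': 'true',
    })
    print(template['completion_prompt_config']['prompt']['text'])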