# prompt_transform.py

import json
import os
import re
import enum
from typing import List, Optional, Tuple

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage

from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from models.model import AppModelConfig


class AppMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

class PromptTransform:
    def get_prompt(self,
                   app_mode: str,
                   app_model_config: AppModelConfig,
                   pre_prompt: str,
                   inputs: dict,
                   query: str,
                   files: List[PromptMessageFile],
                   context: Optional[str],
                   memory: Optional[BaseChatMemory],
                   model_instance: BaseLLM) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        model_mode = app_model_config.model_dict['mode']

        app_mode_enum = AppMode(app_mode)
        model_mode_enum = ModelMode(model_mode)

        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))

        if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
            stops = None

            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
                                                                                    query, context, memory,
                                                                                    model_instance, files)
        else:
            stops = prompt_rules.get('stops')
            if stops is not None and len(stops) == 0:
                stops = None

            prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
                                                                      memory,
                                                                      model_instance, files)

        return prompt_messages, stops
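
    # Illustrative only (caller names are hypothetical): an app in "simple" prompt mode would
    # call get_prompt and feed both return values to the model invocation, e.g.:
    #
    #   transform = PromptTransform()
    #   prompt_messages, stops = transform.get_prompt(
    #       app_mode='chat', app_model_config=app_model_config, pre_prompt=pre_prompt,
    #       inputs=inputs, query=query, files=files, context=context,
    #       memory=memory, model_instance=model_instance)
    #
    # stops is None when both the app and the model run in chat mode; otherwise it comes from
    # the 'stops' entry of the rules file loaded below.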

    def get_advanced_prompt(self,
                            app_mode: str,
                            app_model_config: AppModelConfig,
                            inputs: dict,
                            query: str,
                            files: List[PromptMessageFile],
                            context: Optional[str],
                            memory: Optional[BaseChatMemory],
                            model_instance: BaseLLM) -> List[PromptMessage]:
        model_mode = app_model_config.model_dict['mode']

        app_mode_enum = AppMode(app_mode)
        model_mode_enum = ModelMode(model_mode)

        prompt_messages = []

        if app_mode_enum == AppMode.CHAT:
            if model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
                                                                                      files, context, memory,
                                                                                      model_instance)
            elif model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
                                                                                context, memory, model_instance)
        elif app_mode_enum == AppMode.COMPLETION:
            if model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
                                                                                      files, context)
            elif model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
                                                                                            files, context)

        return prompt_messages

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
                                               max_token_limit: int) -> List[PromptMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory.return_messages = True
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        memory.return_messages = False
        return to_prompt_messages(external_context[memory_key])

    def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
        # baichuan
        if isinstance(model_instance, BaichuanModel):
            return self._prompt_file_name_for_baichuan(mode)

        baichuan_model_hosted_platforms = (HuggingfaceHubModel, OpenLLMModel, XinferenceModel)
        if isinstance(model_instance, baichuan_model_hosted_platforms) and 'baichuan' in model_instance.name.lower():
            return self._prompt_file_name_for_baichuan(mode)

        # common
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _prompt_file_name_for_baichuan(self, mode: str) -> str:
        if mode == 'completion':
            return 'baichuan_completion'
        else:
            return 'baichuan_chat'

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            'generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)
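
    # Illustrative only: a rules file such as generate_prompts/common_chat.json is expected to
    # provide the keys consumed elsewhere in this class. The values below are a sketch of the
    # shape, not the shipped templates:
    #
    # {
    #     "human_prefix": "Human",
    #     "assistant_prefix": "Assistant",
    #     "context_prompt": "Use the following context:\n{{context}}\n",
    #     "histories_prompt": "Conversation so far:\n{{histories}}\n",
    #     "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #     "query_prompt": "{{query}}",
    #     "stops": ["\nHuman:"]
    # }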

    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                                        query: str,
                                                        context: Optional[str],
                                                        memory: Optional[BaseChatMemory],
                                                        model_instance: BaseLLM,
                                                        files: List[PromptMessageFile]) -> List[PromptMessage]:
        prompt_messages = []

        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages

    def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                           query: str,
                                           context: Optional[str],
                                           memory: Optional[BaseChatMemory],
                                           model_instance: BaseLLM,
                                           files: List[PromptMessageFile]) -> List[PromptMessage]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format({'histories': histories})

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = PromptTemplateParser(template=query_prompt)
        query_prompt_content = prompt_template.format({'query': query})

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        return [PromptMessage(content=prompt, files=files)]
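
    # Illustrative only: with the sketched rules above, a pre_prompt of "You are a helpful bot."
    # and one prior exchange in memory, the single completion-mode prompt assembled here would
    # look roughly like (exact layout depends on the JSON templates):
    #
    #   Use the following context:
    #   <retrieved context>
    #   You are a helpful bot.
    #   Conversation so far:
    #   Human: hi
    #   Assistant: hello
    #   <current query>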

    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
        if '#context#' in prompt_template.variable_keys:
            if context:
                prompt_inputs['#context#'] = context
            else:
                prompt_inputs['#context#'] = ''

    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
        if '#query#' in prompt_template.variable_keys:
            if query:
                prompt_inputs['#query#'] = query
            else:
                prompt_inputs['#query#'] = ''

    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
                                prompt_template: PromptTemplateParser, prompt_inputs: dict,
                                model_instance: BaseLLM) -> None:
        if '#histories#' in prompt_template.variable_keys:
            if memory:
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=raw_prompt,
                    inputs={'#histories#': '', **prompt_inputs}
                )

                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

                memory.human_prefix = conversation_histories_role['user_prefix']
                memory.ai_prefix = conversation_histories_role['assistant_prefix']

                histories = self._get_history_messages_from_memory(memory, rest_tokens)
                prompt_inputs['#histories#'] = histories
            else:
                prompt_inputs['#histories#'] = ''
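
    # Illustrative only: in advanced mode the user-authored prompt text can reference the
    # reserved variables filled in by the three setters above, using the same {{...}}
    # placeholder syntax PromptTemplateParser parses for query_prompt. A completion-model
    # prompt might therefore be written as (hypothetical example):
    #
    #   "{{#histories#}}\n\nContext:\n{{#context#}}\n\nQuestion: {{#query#}}\nAnswer:"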

    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
                               model_instance: BaseLLM) -> None:
        if memory:
            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)

            memory.human_prefix = MessageType.USER.value
            memory.ai_prefix = MessageType.ASSISTANT.value

            histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
            prompt_messages.extend(histories)

    def _calculate_rest_token(self, prompt_messages: BaseMessage, model_instance: BaseLLM) -> int:
        rest_tokens = 2000

        if model_instance.model_rules.max_tokens.max:
            curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))
            max_tokens = model_instance.model_kwargs.max_tokens
            rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens
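
    # Illustrative only: assuming a total budget of 4096 tokens (model_rules.max_tokens.max),
    # max_tokens=512 reserved for the completion, and a prompt that already uses 300 tokens,
    # the history budget would be 4096 - 512 - 300 = 3284 tokens; it is clamped to 0 if the
    # prompt alone exceeds the budget, and defaults to 2000 when no maximum is defined.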

    def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
        prompt = prompt_template.format(
            prompt_inputs
        )

        prompt = re.sub(r'<\|.*?\|>', '', prompt)
        return prompt

    def _get_chat_app_completion_model_prompt_messages(self,
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       query: str,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str],
                                                       memory: Optional[BaseChatMemory],
                                                       model_instance: BaseLLM) -> List[PromptMessage]:
        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
        conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

        prompt_messages = []

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)

        self._set_query_variable(query, prompt_template, prompt_inputs)

        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
                                     model_instance)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

        return prompt_messages

    def _get_chat_app_chat_model_prompt_messages(self,
                                                 app_model_config: AppModelConfig,
                                                 inputs: dict,
                                                 query: str,
                                                 files: List[PromptMessageFile],
                                                 context: Optional[str],
                                                 memory: Optional[BaseChatMemory],
                                                 model_instance: BaseLLM) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages
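
    # Illustrative only: chat_prompt_config_dict['prompt'] is iterated as a list of role/text
    # items, with each 'role' coerced through MessageType. Assuming the enum accepts the usual
    # 'system' / 'user' / 'assistant' values, a minimal configuration could look like:
    #
    # {
    #     "prompt": [
    #         {"role": "system", "text": "Answer using this context: {{#context#}}"},
    #         {"role": "user", "text": "{{greeting}}"}
    #     ]
    # }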

    def _get_completion_app_completion_model_prompt_messages(self,
                                                             app_model_config: AppModelConfig,
                                                             inputs: dict,
                                                             files: List[PromptMessageFile],
                                                             context: Optional[str]) -> List[PromptMessage]:
        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

        prompt_messages = []

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

        return prompt_messages

    def _get_completion_app_chat_model_prompt_messages(self,
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str]) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        for prompt_message in prompt_messages[::-1]:
            if prompt_message.type == MessageType.USER:
                prompt_message.files = files
                break

        return prompt_messages