completion.py

import logging
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message
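

# Completion orchestrates a single app completion request: it rebuilds the
# conversation memory (if any), validates the remaining token budget, runs the
# optional main chain built from the app's agent mode, then calls the final LLM
# with the assembled prompt.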
class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get the conversation's memory (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            tenant_id=app.tenant_id,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            rest_tokens=rest_tokens_for_context_and_memory,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # stream interrupted by the LLM provider (e.g. OpenAI); end the task gracefully
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return
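
    # run_final_llm: build the provider LLM, assemble the prompt (optionally
    # including the chain output and conversation memory), attach callbacks,
    # rebalance max_tokens against the prompt length, then generate.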
    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response
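
    # get_main_llm_prompt returns a (prompt, stop_words) tuple. In 'completion'
    # mode the prompt is a single formatted string (wrapped in a HumanMessage for
    # chat models) with no stop words; otherwise it is a list of BaseMessage with
    # chat histories embedded in <histories> tags and '\nHuman:' as a stop word.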
    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                            pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
        # disable template string in query
        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
        # if query_params:
        #     for query_param in query_params:
        #         if query_param not in inputs:
        #             inputs[query_param] = '{{' + query_param + '}}'

        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

            if chain_output:
                inputs['context'] = chain_output
                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
                # if context_params:
                #     for context_param in context_params:
                #         if context_param not in inputs:
                #             inputs[context_param] = '{{' + context_param + '}}'

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if chain_output:
                human_inputs['context'] = chain_output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
                model_name = model['name']
                max_tokens = model.get("completion_params").get('max_tokens')
                rest_tokens = llm_constant.max_context_token_length[model_name] \
                              - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)

                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                # disable template string in query
                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
                # if histories_params:
                #     for histories_param in histories_params:
                #         if histories_param not in human_inputs:
                #             human_inputs[histories_param] = '{{' + histories_param + '}}'

                human_message_prompt += "\n\n" if human_message_prompt else ""
                human_message_prompt += "Here is the chat histories between human and assistant, " \
                                        "inside <histories></histories> XML tags.\n\n<histories>"
                human_message_prompt += histories + "</histories>"

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            return messages, ['\nHuman:']
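
    # Callbacks: LLMCallbackHandler routes generation events to the conversation
    # task; a Dify stdout handler (streaming or blocking variant) echoes output
    # locally.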
    @classmethod
    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            return [llm_callback_handler, DifyStdOutCallbackHandler()]

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # this llm is only used to count tokens for the memory buffer
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory
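
    # get_validate_rest_tokens computes the token budget left for context and
    # memory as model_limit - max_tokens - prompt_tokens, and rejects the request
    # if the prompt alone already overflows. For example (assumed numbers): a
    # 4,096-token model with max_tokens=512 and a 600-token prompt leaves 2,984
    # tokens for retrieved context and chat history.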
    @classmethod
    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_name = app_model_config.model_dict.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')

        # get prompt without memory and context
        prompt, _ = cls.get_main_llm_prompt(
            mode=mode,
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=None,
            memory=None
        )

        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
            else llm.get_num_tokens_from_messages(prompt)

        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. Shorten the prefix prompt, "
                                     "lower max_tokens, or switch to an LLM with a larger context limit.")

        return rest_tokens
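
    # recale_llm_max_tokens shrinks the completion budget when prompt_tokens +
    # max_tokens would overflow the model limit. For example (assumed numbers): a
    # 3,900-token prompt against a 4,096-token model caps max_tokens at 196
    # (never below the floor of 16).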
    @classmethod
    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_name = model.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = model.get("completion_params").get('max_tokens')

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens
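
    # generate_more_like_this rebuilds the original completion prompt, combines it
    # with the previous answer via MORE_LIKE_THIS_GENERATE_PROMPT, and generates an
    # alternative completion under a fresh conversation task.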
    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=app.tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
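

# Illustrative usage (a sketch, not part of this module): a caller that has
# already loaded the App, AppModelConfig, Account, and optional Conversation
# records would typically drive the pipeline roughly like this:
#
#     import uuid
#
#     Completion.generate(
#         task_id=str(uuid.uuid4()),   # hypothetical task id
#         app=app,
#         app_model_config=app_model_config,
#         query="Summarize the attached report.",
#         inputs={},
#         user=account,
#         conversation=conversation,   # or None for a one-off completion
#         streaming=True,
#     )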