completion.py

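"""Completion pipeline: builds the main LLM prompt, runs the optional agent
chain, and calls the final LLM for both completion and chat app modes."""
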
from typing import Optional, List, Union

from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage

from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        cls.validate_query_tokens(app.tenant_id, app_model_config, query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation
            )

            inputs = conversation.inputs

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
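        """Build the final prompt, attach callbacks, recalculate max_tokens and
        call the underlying LLM; returns the raw LLM generation result."""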
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt])

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Union[str, List[BaseMessage]]:
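        """Assemble the prompt for the final LLM call.

        In completion mode this returns a single formatted prompt string (or a
        one-message list for chat models); in chat mode it returns a message
        list with an optional system message, trimmed chat history and the
        user query."""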
        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
        if mode == 'completion':
            prompt_template = OutLinePromptTemplate.from_template(
                template=("Use the following pieces of [CONTEXT] to answer the question at the end. "
                          "If you don't know the answer, "
                          "just say that you don't know, don't try to make up an answer. \n"
                          "```\n"
                          "[CONTEXT]\n"
                          "{context}\n"
                          "```\n" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{query}\n"
            )

            if chain_output:
                inputs['context'] = chain_output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)]
            else:
                return prompt_content
        else:
            messages: List[BaseMessage] = []

            system_message = None
            if pre_prompt:
                # append pre prompt as system message
                system_message = PromptBuilder.to_system_message(pre_prompt, inputs)

            if chain_output:
                # append context as system message, currently only use simple stuff prompt
                context_message = PromptBuilder.to_system_message(
                    """Use the following pieces of [CONTEXT] to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
```
[CONTEXT]
{context}
```""",
                    {'context': chain_output}
                )

                if not system_message:
                    system_message = context_message
                else:
                    system_message.content = context_message.content + "\n\n" + system_message.content

            if system_message:
                messages.append(system_message)

            human_inputs = {
                "query": query
            }

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content="{query}",
                inputs=human_inputs
            )

            if memory:
                # append chat histories
                tmp_messages = messages.copy() + [human_message]
                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
                rest_tokens = llm_constant.max_context_token_length[
                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
                messages += history_messages

            messages.append(human_message)

            return messages

    @classmethod
    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
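        """Build the CallbackManager for the LLM: the conversation callback
        handler plus a streaming or non-streaming stdout handler."""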
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]

        return CallbackManager(callback_handlers)

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> List[BaseMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
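        """Build a read-only, token-limited memory over the conversation's
        stored messages; the LLM here is only used for token counting."""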
        # only used to calculate tokens held in memory
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
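        """Raise LLMBadRequestError if the query alone would not fit within the
        model's context window after reserving max_tokens for the answer."""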
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
            raise LLMBadRequestError("Query is too long")

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
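        """Regenerate an answer similar to an existing message by feeding the
        original prompt and completion into MORE_LIKE_THIS_GENERATE_PROMPT."""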
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # get llm prompt
        original_prompt = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
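

# A minimal usage sketch (illustrative only): the App, AppModelConfig and Account
# objects, the database session and a configured model provider are assumed to be
# set up by the surrounding application before calling this.
#
#   import uuid
#
#   Completion.generate(
#       task_id=str(uuid.uuid4()),
#       app=app,
#       app_model_config=app_model_config,
#       query="What can you do?",
#       inputs={},
#       user=current_user,
#       conversation=None,
#       streaming=False
#   )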