completion.py

import json
import logging
import re
from typing import Optional, List, Union, Tuple

from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False, retriever_from: str = 'dev'):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming,
            model_instance=final_model_instance
        )

        # token budget left for retrieved context and conversation memory
        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        # parse sensitive_word_avoidance_chain
        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)

        # get agent executor
        agent_executor = orchestrator_rule_parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens_for_context_and_memory,
            chain_callback=chain_callback
        )

        # run agent executor
        agent_execute_result = None
        if agent_executor:
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # Stream interrupted by the LLM provider (e.g. OpenAI); end the task gracefully.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
        # When no extra pre-prompt is specified, the agent's output can be used directly
        # as the final answer without calling the LLM again.
        fake_response = None
        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
            fake_response = agent_execute_result.output

        # get llm prompt
        prompt_messages, stop_words = model_instance.get_prompt(
            mode=mode,
            pre_prompt=app_model_config.pre_prompt,
            inputs=inputs,
            query=query,
            context=agent_execute_result.output if agent_execute_result else None,
            memory=memory
        )

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]
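
    # Usage sketch (illustrative, not part of the original file): given a `memory`
    # object built by get_memory_from_conversation() below, the trimmed chat history
    # can be read back as a single string, e.g.:
    #
    #     history = Completion.get_history_messages_from_memory(memory, max_token_limit=rest_tokens)
    #
    # where `rest_tokens` is a caller-chosen token budget (hypothetical name here).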

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only used to count tokens held in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt without memory and context
        prompt_messages, _ = model_instance.get_prompt(
            mode=mode,
            pre_prompt=app_model_config.pre_prompt,
            inputs=inputs,
            query=query,
            context=None,
            memory=None
        )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "or shrink the max token, or switch to an LLM with a larger token limit.")

        return rest_tokens
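
    # Worked example (illustrative numbers, not from the original file): with a model
    # whose context limit is 4096 tokens, a configured max_tokens of 512, and a prompt
    # (pre-prompt + inputs + query, without memory or context) measuring 300 tokens,
    # rest_tokens = 4096 - 512 - 300 = 3284 tokens remain for retrieved context and
    # conversation memory; a negative result raises LLMBadRequestError instead.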

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)
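
    # Worked example (illustrative numbers, not from the original file): with a
    # 4096-token model limit, a 3900-token prompt, and a requested max_tokens of 1024,
    # the sum (4924) exceeds the limit, so max_tokens is clamped to
    # max(4096 - 3900, 16) = 196 before the model is invoked.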

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        old_prompt_messages, _ = final_model_instance.get_prompt(
            mode='completion',
            pre_prompt=pre_prompt,
            inputs=message.inputs,
            query=message.query,
            context=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)

        prompt_messages = [PromptMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming,
            model_instance=final_model_instance
        )

        cls.recale_llm_max_tokens(
            model_instance=final_model_instance,
            prompt_messages=prompt_messages
        )

        final_model_instance.run(
            messages=prompt_messages,
            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
        )
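
# Usage sketch (illustrative, not part of the original file): a caller such as a web
# handler is assumed to have loaded `app`, `app_model_config`, and `current_user` from
# the database before driving a completion roughly like this; the values below are
# hypothetical.
#
#     import uuid
#
#     Completion.generate(
#         task_id=str(uuid.uuid4()),
#         app=app,
#         app_model_config=app_model_config,
#         query="Summarize today's tickets",
#         inputs={},
#         user=current_user,
#         conversation=None,   # None starts without prior conversation memory
#         streaming=True,
#         is_override=False,
#         retriever_from='dev'
#     )
#
# generate() returns nothing; output is surfaced through ConversationMessageTask and
# LLMCallbackHandler rather than as a return value.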