completion.py

import logging
from typing import Optional, List, Union

from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser

class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False, retriever_from: str = 'dev'):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptTemplateParser.remove_template_variables(query)

        memory = None
        if conversation:
            # get read-only memory for the conversation
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming,
            model_instance=final_model_instance
        )

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            # parse sensitive_word_avoidance_chain
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
                final_model_instance, [chain_callback])
            if sensitive_word_avoidance_chain:
                try:
                    query = sensitive_word_avoidance_chain.run(query)
                except SensitiveWordAvoidanceError as ex:
                    cls.run_final_llm(
                        model_instance=final_model_instance,
                        mode=app.mode,
                        app_model_config=app_model_config,
                        query=query,
                        inputs=inputs,
                        agent_execute_result=None,
                        conversation_message_task=conversation_message_task,
                        memory=memory,
                        fake_response=ex.message
                    )
                    return

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre-prompt is specified, the agent's output can be used
            # directly as the main output without calling the LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # Interrupted by the LLM provider (e.g. OpenAI); end the task gracefully.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return
    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        # in completion mode, the agent query comes from the configured dataset query variable
        return inputs.get(app_model_config.dataset_query_variable, "")
    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
            # in advanced prompt mode, stop words come from the configured completion params
            model_config = app_model_config.model_dict
            completion_params = model_config.get("completion_params", {})
            stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response
    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]
    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # this model instance is only used for counting tokens in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use the llm config from the conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory
    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        """Return the tokens left for context and memory after reserving the prompt and max_tokens."""
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()
        prompt_messages = []

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        # e.g. a 4096-token model with max_tokens=512 and a 300-token prompt
        # leaves 4096 - 512 - 300 = 3284 tokens for context and memory
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. Reduce the prefix prompt, "
                                     "shrink max_tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens
    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)
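

# Illustrative usage (a minimal sketch, not part of the original module): how a
# caller might invoke Completion.generate. The `app`, `app_model_config`, and
# `current_user` objects below are hypothetical placeholders assumed to be loaded
# from the database by the application layer; only the keyword arguments mirror
# the signature of Completion.generate above.
#
#   import uuid
#
#   Completion.generate(
#       task_id=str(uuid.uuid4()),
#       app=app,                            # models.model.App record
#       app_model_config=app_model_config,  # active AppModelConfig for the app
#       query="Summarize the ticket backlog.",
#       inputs={},                          # template variables for the prompt
#       user=current_user,                  # Account or EndUser
#       conversation=None,                  # None starts a new conversation
#       streaming=True,
#       is_override=False,
#       retriever_from='dev'
#   )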