# application_manager.py
  1. import json
  2. import logging
  3. import threading
  4. import uuid
  5. from collections.abc import Generator
  6. from typing import Any, Optional, Union, cast
  7. from flask import Flask, current_app
  8. from pydantic import ValidationError
  9. from core.app_runner.assistant_app_runner import AssistantApplicationRunner
  10. from core.app_runner.basic_app_runner import BasicApplicationRunner
  11. from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
  12. from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
  13. from core.entities.application_entities import (
  14. AdvancedChatPromptTemplateEntity,
  15. AdvancedCompletionPromptTemplateEntity,
  16. AgentEntity,
  17. AgentPromptEntity,
  18. AgentToolEntity,
  19. ApplicationGenerateEntity,
  20. AppOrchestrationConfigEntity,
  21. DatasetEntity,
  22. DatasetRetrieveConfigEntity,
  23. ExternalDataVariableEntity,
  24. FileUploadEntity,
  25. InvokeFrom,
  26. ModelConfigEntity,
  27. PromptTemplateEntity,
  28. SensitiveWordAvoidanceEntity,
  29. )
  30. from core.entities.model_entities import ModelStatus
  31. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  32. from core.file.file_obj import FileObj
  33. from core.model_runtime.entities.message_entities import PromptMessageRole
  34. from core.model_runtime.entities.model_entities import ModelType
  35. from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
  36. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  37. from core.prompt.prompt_template import PromptTemplateParser
  38. from core.provider_manager import ProviderManager
  39. from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
  40. from extensions.ext_database import db
  41. from models.account import Account
  42. from models.model import App, Conversation, EndUser, Message, MessageFile
  43. logger = logging.getLogger(__name__)
  44. class ApplicationManager:
  45. """
  46. This class is responsible for managing application
  47. """
  def generate(self, tenant_id: str,
               app_id: str,
               app_model_config_id: str,
               app_model_config_dict: dict,
               app_model_config_override: bool,
               user: Union[Account, EndUser],
               invoke_from: InvokeFrom,
               inputs: dict[str, str],
               query: Optional[str] = None,
               files: Optional[list[FileObj]] = None,
               conversation: Optional[Conversation] = None,
               stream: bool = False,
               extras: Optional[dict[str, Any]] = None) \
          -> Union[dict, Generator]:
      """
      Run one generation request for an app.

      Builds the generate entity from the raw arguments, creates the
      conversation/message records, starts the runner in a worker thread and
      returns whatever ``_handle_response`` produces.

      :param tenant_id: workspace ID
      :param app_id: app ID
      :param app_model_config_id: app model config id
      :param app_model_config_dict: app model config dict
      :param app_model_config_override: app model config override
      :param user: account or end user
      :param invoke_from: invoke from source
      :param inputs: inputs (ignored in favour of ``conversation.inputs`` when a conversation is given)
      :param query: query
      :param files: file obj list
      :param conversation: conversation
      :param stream: is stream
      :param extras: extras
      :raises ValueError: when the app has an agent configured and ``stream`` is false
      :return: the value returned by ``_handle_response``
      """
      # every request gets its own task id
      task_id = str(uuid.uuid4())

      # resolve the orchestration config up front; the entity below embeds it
      orchestration_config = self._convert_from_app_model_config_dict(
          tenant_id=tenant_id,
          app_model_config_dict=app_model_config_dict
      )

      # an existing conversation supplies both its id and its stored inputs
      if conversation:
          conversation_id = conversation.id
          entity_inputs = conversation.inputs
      else:
          conversation_id = None
          entity_inputs = inputs

      # NUL characters are removed from the query before it is used
      cleaned_query = query.replace('\x00', '') if query else None
      entity_files = files if files else []

      application_generate_entity = ApplicationGenerateEntity(
          task_id=task_id,
          tenant_id=tenant_id,
          app_id=app_id,
          app_model_config_id=app_model_config_id,
          app_model_config_dict=app_model_config_dict,
          app_orchestration_config_entity=orchestration_config,
          app_model_config_override=app_model_config_override,
          conversation_id=conversation_id,
          inputs=entity_inputs,
          query=cleaned_query,
          files=entity_files,
          user_id=user.id,
          stream=stream,
          invoke_from=invoke_from,
          extras=extras
      )

      if not stream:
          if application_generate_entity.app_orchestration_config_entity.agent:
              raise ValueError("Agent app is not supported in blocking mode.")

      # persist (or load) the conversation and create the message row
      conversation, message = self._init_generate_records(application_generate_entity)

      queue_manager = ApplicationQueueManager(
          task_id=application_generate_entity.task_id,
          user_id=application_generate_entity.user_id,
          invoke_from=application_generate_entity.invoke_from,
          conversation_id=conversation.id,
          app_mode=conversation.mode,
          message_id=message.id
      )

      # the worker receives ids, not ORM objects, and re-queries them itself;
      # it also gets the real Flask app object so it can push its own app context
      worker_kwargs = {
          'flask_app': current_app._get_current_object(),
          'application_generate_entity': application_generate_entity,
          'queue_manager': queue_manager,
          'conversation_id': conversation.id,
          'message_id': message.id,
      }
      worker_thread = threading.Thread(target=self._generate_worker, kwargs=worker_kwargs)
      worker_thread.start()

      # return response or stream generator
      return self._handle_response(
          application_generate_entity=application_generate_entity,
          queue_manager=queue_manager,
          conversation=conversation,
          message=message,
          stream=stream
      )
  def _generate_worker(self, flask_app: Flask,
                       application_generate_entity: ApplicationGenerateEntity,
                       queue_manager: ApplicationQueueManager,
                       conversation_id: str,
                       message_id: str) -> None:
      """
      Thread target: run the application runner inside a Flask app context.

      A stopped task is swallowed silently; every other failure is published
      to the queue manager instead of being raised. The scoped DB session is
      removed when the worker finishes, whatever the outcome.

      :param flask_app: Flask app
      :param application_generate_entity: application generate entity
      :param queue_manager: queue manager
      :param conversation_id: conversation ID
      :param message_id: message ID
      :return:
      """
      with flask_app.app_context():
          try:
              # load the records by id inside this thread
              conversation = self._get_conversation(conversation_id)
              message = self._get_message(message_id)

              # agent apps use the assistant runner, everything else the basic one
              agent_config = application_generate_entity.app_orchestration_config_entity.agent
              runner_cls = AssistantApplicationRunner if agent_config else BasicApplicationRunner

              runner = runner_cls()
              runner.run(
                  application_generate_entity=application_generate_entity,
                  queue_manager=queue_manager,
                  conversation=conversation,
                  message=message
              )
          except ConversationTaskStoppedException:
              # the task was stopped on purpose; nothing to report
              pass
          except InvokeAuthorizationError:
              # replace the provider's message with a fixed one before publishing
              queue_manager.publish_error(
                  InvokeAuthorizationError('Incorrect API key provided'),
                  PublishFrom.APPLICATION_MANAGER
              )
          except ValidationError as err:
              logger.exception("Validation Error when generating")
              queue_manager.publish_error(err, PublishFrom.APPLICATION_MANAGER)
          except (ValueError, InvokeError) as err:
              # published without a log entry
              queue_manager.publish_error(err, PublishFrom.APPLICATION_MANAGER)
          except Exception as err:
              logger.exception("Unknown Error when generating")
              queue_manager.publish_error(err, PublishFrom.APPLICATION_MANAGER)
          finally:
              db.session.remove()
  def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                       queue_manager: ApplicationQueueManager,
                       conversation: Conversation,
                       message: Message,
                       stream: bool = False) -> Union[dict, Generator]:
      """
      Handle response.

      Builds a ``GenerateTaskPipeline`` for the request and returns the result
      of its ``process`` call. The scoped DB session is removed before this
      method returns or raises.

      :param application_generate_entity: application generate entity
      :param queue_manager: queue manager
      :param conversation: conversation
      :param message: message
      :param stream: is stream
      :raises ConversationTaskStoppedException: when ``process`` fails with the
          "I/O operation on closed file." ValueError
      :raises ValueError: any other ValueError from ``process`` (logged, then re-raised)
      :return: whatever ``GenerateTaskPipeline.process`` returns
      """
      # init generate task pipeline
      generate_task_pipeline = GenerateTaskPipeline(
          application_generate_entity=application_generate_entity,
          queue_manager=queue_manager,
          conversation=conversation,
          message=message
      )

      # NOTE(review): if ``process`` returns a lazily evaluated generator when
      # stream=True, errors raised while iterating it are not caught here and the
      # session is removed before iteration starts — confirm against the pipeline.
      try:
          return generate_task_pipeline.process(stream=stream)
      except ValueError as e:
          # ``e.args`` can be empty (bare ``ValueError()``); check it before indexing
          # so the real error is not masked by an IndexError
          if e.args and e.args[0] == "I/O operation on closed file.":  # ignore this error
              # presumably the output stream was closed by the consumer — treated as a stop
              raise ConversationTaskStoppedException() from e

          logger.exception(e)
          # bare raise keeps the original traceback intact
          raise
      finally:
          db.session.remove()
  def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
          -> AppOrchestrationConfigEntity:
      """
      Convert app model config dict to entity.

      The input dict is only read: nested values are copied before being
      modified, so the caller's ``completion_params`` keep their ``stop`` entry.

      :param tenant_id: tenant ID
      :param app_model_config_dict: app model config dict
      :raises ProviderTokenNotInitError: model credentials are missing or the model is not configured
      :raises ModelCurrentlyNotSupportError: the model status is NO_PERMISSION
      :raises QuotaExceededError: the model status is QUOTA_EXCEEDED
      :raises ValueError: the provider model or its schema cannot be found
      :raises KeyError: a required key (``model``, ``prompt_type``, ...) is missing
      :return: app orchestration config entity
      """
      properties = {}

      # top-level shallow copy; nested dicts are still shared with the caller
      copy_app_model_config_dict = app_model_config_dict.copy()
      model_dict = copy_app_model_config_dict['model']

      provider_manager = ProviderManager()
      provider_model_bundle = provider_manager.get_provider_model_bundle(
          tenant_id=tenant_id,
          provider=model_dict['provider'],
          model_type=ModelType.LLM
      )

      provider_name = provider_model_bundle.configuration.provider.provider
      model_name = model_dict['name']

      model_type_instance = provider_model_bundle.model_type_instance
      model_type_instance = cast(LargeLanguageModel, model_type_instance)

      # check model credentials
      model_credentials = provider_model_bundle.configuration.get_current_credentials(
          model_type=ModelType.LLM,
          model=model_name
      )

      if model_credentials is None:
          raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

      # check model
      provider_model = provider_model_bundle.configuration.get_provider_model(
          model=model_name,
          model_type=ModelType.LLM
      )

      if provider_model is None:
          raise ValueError(f"Model {model_name} not exist.")

      if provider_model.status == ModelStatus.NO_CONFIGURE:
          raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
      elif provider_model.status == ModelStatus.NO_PERMISSION:
          raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
      elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
          raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

      # model config
      # Work on a copy: 'stop' is moved out of the parameters below, and removing
      # it from the shared nested dict would also strip it from the caller's config.
      # `or {}` covers a missing or null completion_params.
      completion_params = dict(model_dict.get('completion_params') or {})
      stop = completion_params.pop('stop', [])

      # get model mode; fall back to asking the model implementation when unset
      model_mode = model_dict.get('mode')
      if not model_mode:
          mode_enum = model_type_instance.get_model_mode(
              model=model_name,
              credentials=model_credentials
          )

          model_mode = mode_enum.value

      model_schema = model_type_instance.get_model_schema(
          model_name,
          model_credentials
      )

      if not model_schema:
          raise ValueError(f"Model {model_name} not exist.")

      properties['model_config'] = ModelConfigEntity(
          provider=model_dict['provider'],
          model=model_name,
          model_schema=model_schema,
          mode=model_mode,
          provider_model_bundle=provider_model_bundle,
          credentials=model_credentials,
          parameters=completion_params,
          stop=stop,
      )

      # prompt template
      prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
      if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
          simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
          properties['prompt_template'] = PromptTemplateEntity(
              prompt_type=prompt_type,
              simple_prompt_template=simple_prompt_template
          )
      else:
          advanced_chat_prompt_template = None
          chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
          if chat_prompt_config:
              chat_prompt_messages = [
                  {
                      "text": message["text"],
                      "role": PromptMessageRole.value_of(message["role"])
                  }
                  for message in chat_prompt_config.get("prompt", [])
              ]

              advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                  messages=chat_prompt_messages
              )

          advanced_completion_prompt_template = None
          completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
          if completion_prompt_config:
              completion_prompt_template_params = {
                  'prompt': completion_prompt_config['prompt']['text'],
              }

              if 'conversation_histories_role' in completion_prompt_config:
                  histories_role = completion_prompt_config['conversation_histories_role']
                  completion_prompt_template_params['role_prefix'] = {
                      'user': histories_role['user_prefix'],
                      'assistant': histories_role['assistant_prefix']
                  }

              advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                  **completion_prompt_template_params
              )

          properties['prompt_template'] = PromptTemplateEntity(
              prompt_type=prompt_type,
              advanced_chat_prompt_template=advanced_chat_prompt_template,
              advanced_completion_prompt_template=advanced_completion_prompt_template
          )

      # external data variables
      properties['external_data_variables'] = []

      # old external_data_tools
      for external_data_tool in copy_app_model_config_dict.get('external_data_tools', []):
          if not external_data_tool.get('enabled'):
              continue

          properties['external_data_variables'].append(
              ExternalDataVariableEntity(
                  variable=external_data_tool['variable'],
                  type=external_data_tool['type'],
                  config=external_data_tool['config']
              )
          )

      # current external_data_tools: form items are single-key dicts keyed by their type
      for variable in copy_app_model_config_dict.get('user_input_form', []):
          # next(iter(...), None) tolerates an empty item instead of raising IndexError
          typ = next(iter(variable), None)
          if typ == 'external_data_tool':
              val = variable[typ]
              properties['external_data_variables'].append(
                  ExternalDataVariableEntity(
                      variable=val['variable'],
                      type=val['type'],
                      config=val['config']
                  )
              )

      # show retrieve source
      retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
      properties['show_retrieve_source'] = bool(
          retriever_resource_dict and retriever_resource_dict.get('enabled')
      )

      dataset_ids = []
      # `or {}` also covers an explicit null value
      dataset_configs = copy_app_model_config_dict.get('dataset_configs') or {}
      if 'datasets' in dataset_configs:
          datasets = dataset_configs['datasets']

          for dataset_item in datasets.get('datasets', []):
              # only single-key {'dataset': {...}} wrappers are considered
              if next(iter(dataset_item), None) != 'dataset':
                  continue
              dataset = dataset_item['dataset']

              if not dataset.get('enabled'):
                  continue

              dataset_id = dataset.get('id', None)
              if dataset_id:
                  dataset_ids.append(dataset_id)
      else:
          datasets = {'strategy': 'router', 'datasets': []}

      agent_dict = copy_app_model_config_dict.get('agent_mode')
      if agent_dict and agent_dict.get('enabled'):
          agent_strategy = agent_dict.get('strategy', 'cot')

          if agent_strategy == 'function_call':
              strategy = AgentEntity.Strategy.FUNCTION_CALLING
          elif agent_strategy in ('cot', 'react'):
              strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
          elif model_dict['provider'] == 'openai':
              # old configs, try to detect default strategy
              strategy = AgentEntity.Strategy.FUNCTION_CALLING
          else:
              strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

          agent_tools = []
          for tool in agent_dict.get('tools', []):
              key_count = len(tool)
              if key_count >= 4:
                  if not tool.get('enabled'):
                      continue

                  agent_tool_properties = {
                      'provider_type': tool['provider_type'],
                      'provider_id': tool['provider_id'],
                      'tool_name': tool['tool_name'],
                      'tool_parameters': tool.get('tool_parameters', {})
                  }

                  agent_tools.append(AgentToolEntity(**agent_tool_properties))
              elif key_count == 1:
                  # old standard: a single-key {'dataset': {...}} entry contributes a dataset id
                  key = next(iter(tool))
                  if key != 'dataset':
                      continue

                  tool_item = tool[key]
                  if not tool_item.get('enabled'):
                      continue

                  dataset_ids.append(tool_item['id'])

          # No agent entity is built for the router strategies (or when no strategy
          # is stored); only the dataset ids collected above are used in that case.
          # NOTE(review): nesting reconstructed from a listing without indentation —
          # the entity is created only in this branch, where its prompt is defined.
          if 'strategy' in agent_dict and agent_dict['strategy'] not in ['react_router', 'router']:
              agent_prompt = agent_dict.get('prompt', None) or {}

              # check model mode
              # NOTE(review): this reads the stored mode (default 'completion'), not the
              # mode resolved for ModelConfigEntity above — confirm this is intended.
              agent_model_mode = model_dict.get('mode', 'completion')
              template_key = 'completion' if agent_model_mode == 'completion' else 'chat'
              default_templates = REACT_PROMPT_TEMPLATES['english'][template_key]

              agent_prompt_entity = AgentPromptEntity(
                  first_prompt=agent_prompt.get('first_prompt', default_templates['prompt']),
                  next_iteration=agent_prompt.get('next_iteration', default_templates['agent_scratchpad']),
              )

              properties['agent'] = AgentEntity(
                  provider=properties['model_config'].provider,
                  model=properties['model_config'].model,
                  strategy=strategy,
                  prompt=agent_prompt_entity,
                  tools=agent_tools,
                  max_iteration=agent_dict.get('max_iteration', 5)
              )

      if dataset_ids:
          # dataset configs
          query_variable = copy_app_model_config_dict.get('dataset_query_variable')
          # default to 'single' when the key is absent, matching the default that was
          # previously applied only when 'dataset_configs' itself was missing
          retrieval_model = dataset_configs.get('retrieval_model', 'single')

          if retrieval_model == 'single':
              strategy_kwargs = {
                  'single_strategy': datasets.get('strategy', 'router')
              }
          else:
              strategy_kwargs = {
                  'top_k': dataset_configs.get('top_k'),
                  'score_threshold': dataset_configs.get('score_threshold'),
                  'reranking_model': dataset_configs.get('reranking_model')
              }

          properties['dataset'] = DatasetEntity(
              dataset_ids=dataset_ids,
              retrieve_config=DatasetRetrieveConfigEntity(
                  query_variable=query_variable,
                  retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                      retrieval_model
                  ),
                  **strategy_kwargs
              )
          )

      # file upload (only the image section is read)
      file_upload_dict = copy_app_model_config_dict.get('file_upload')
      if file_upload_dict:
          image_dict = file_upload_dict.get('image')
          if image_dict and image_dict.get('enabled'):
              properties['file_upload'] = FileUploadEntity(
                  image_config={
                      'number_limits': image_dict['number_limits'],
                      'detail': image_dict['detail'],
                      'transfer_methods': image_dict['transfer_methods']
                  }
              )

      # opening statement
      properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')

      # simple on/off features: set only when present and enabled
      for feature in ('suggested_questions_after_answer', 'more_like_this',
                      'speech_to_text', 'text_to_speech'):
          feature_dict = copy_app_model_config_dict.get(feature)
          if feature_dict and feature_dict.get('enabled'):
              properties[feature] = True

      # sensitive word avoidance
      sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
      if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get('enabled'):
          properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
              type=sensitive_word_avoidance_dict.get('type'),
              config=sensitive_word_avoidance_dict.get('config'),
          )

      return AppOrchestrationConfigEntity(**properties)
  def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
          -> tuple[Conversation, Message]:
      """
      Initialize generate records

      Creates a new conversation when the entity carries no conversation id
      (otherwise loads the existing one for this app), then creates the message
      row and one ``MessageFile`` row per attached file.

      :param application_generate_entity: application generate entity
      :raises ValueError: when the app, or the referenced conversation, does not exist
      :return: (conversation, message)
      """
      app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
      model_config = app_orchestration_config_entity.model_config

      model_type_instance = model_config.provider_model_bundle.model_type_instance
      model_type_instance = cast(LargeLanguageModel, model_type_instance)
      model_schema = model_type_instance.get_model_schema(
          model=model_config.model,
          credentials=model_config.credentials
      )

      app_record = (db.session.query(App)
                    .filter(App.id == application_generate_entity.app_id).first())
      # .first() yields None when no row matches; fail with a clear message
      # instead of an AttributeError on the next line
      if app_record is None:
          raise ValueError(f"App {application_generate_entity.app_id} not exist.")

      app_mode = app_record.mode

      # get from source: web app / service API calls are attributed to an end user,
      # everything else to a console account
      end_user_id = None
      account_id = None
      if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
          from_source = 'api'
          end_user_id = application_generate_entity.user_id
      else:
          from_source = 'console'
          account_id = application_generate_entity.user_id

      override_model_configs = None
      if application_generate_entity.app_model_config_override:
          override_model_configs = application_generate_entity.app_model_config_dict
      # serialized once and reused for both the conversation and the message
      override_model_configs_json = json.dumps(override_model_configs) if override_model_configs else None

      introduction = ''
      if app_mode == 'chat':
          # get conversation introduction
          introduction = self._get_conversation_introduction(application_generate_entity)

      if not application_generate_entity.conversation_id:
          conversation = Conversation(
              app_id=app_record.id,
              app_model_config_id=application_generate_entity.app_model_config_id,
              model_provider=model_config.provider,
              model_id=model_config.model,
              override_model_configs=override_model_configs_json,
              mode=app_mode,
              name='New conversation',
              inputs=application_generate_entity.inputs,
              introduction=introduction,
              system_instruction="",
              system_instruction_tokens=0,
              status='normal',
              from_source=from_source,
              from_end_user_id=end_user_id,
              from_account_id=account_id,
          )

          db.session.add(conversation)
          db.session.commit()
      else:
          conversation = (
              db.session.query(Conversation)
              .filter(
                  Conversation.id == application_generate_entity.conversation_id,
                  Conversation.app_id == app_record.id
              ).first()
          )
          # the id may not exist, or may belong to a different app
          if conversation is None:
              raise ValueError(f"Conversation {application_generate_entity.conversation_id} not exist.")

      # fall back to USD when the schema (or its pricing) is unavailable
      currency = model_schema.pricing.currency if model_schema and model_schema.pricing else 'USD'

      message = Message(
          app_id=app_record.id,
          model_provider=model_config.provider,
          model_id=model_config.model,
          override_model_configs=override_model_configs_json,
          conversation_id=conversation.id,
          inputs=application_generate_entity.inputs,
          query=application_generate_entity.query or "",
          message="",
          message_tokens=0,
          message_unit_price=0,
          message_price_unit=0,
          answer="",
          answer_tokens=0,
          answer_unit_price=0,
          answer_price_unit=0,
          provider_response_latency=0,
          total_price=0,
          currency=currency,
          from_source=from_source,
          from_end_user_id=end_user_id,
          from_account_id=account_id,
          agent_based=app_orchestration_config_entity.agent is not None
      )

      db.session.add(message)
      db.session.commit()

      files = application_generate_entity.files
      if files:
          # read once; these values are the same for every file row
          message_id = message.id
          created_by_role = 'account' if account_id else 'end_user'
          created_by = account_id or end_user_id

          for file in files:
              db.session.add(MessageFile(
                  message_id=message_id,
                  type=file.type.value,
                  transfer_method=file.transfer_method.value,
                  belongs_to='user',
                  url=file.url,
                  upload_file_id=file.upload_file_id,
                  created_by_role=created_by_role,
                  created_by=created_by,
              ))

          # one commit for the whole batch of file rows
          db.session.commit()

      return conversation, message
  def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
      """
      Render the app's opening statement with the request inputs.

      :param application_generate_entity: application generate entity
      :return: the formatted opening statement; the raw (possibly empty or None)
          opening statement when there is nothing to format or formatting
          raises KeyError
      """
      opening_statement = application_generate_entity.app_orchestration_config_entity.opening_statement
      if not opening_statement:
          # nothing to render; hand the falsy value back untouched
          return opening_statement

      try:
          available_inputs = application_generate_entity.inputs
          parser = PromptTemplateParser(template=opening_statement)

          # pass along only the template variables that were actually supplied
          template_inputs = {}
          for key in parser.variable_keys:
              if key in available_inputs:
                  template_inputs[key] = available_inputs[key]

          return parser.format(template_inputs)
      except KeyError:
          # best effort: fall back to the unformatted opening statement
          return opening_statement
  def _get_conversation(self, conversation_id: str) -> Conversation:
      """
      Look up a conversation by its id.

      :param conversation_id: conversation id
      :return: the first matching row as returned by ``Query.first()``
          (which yields None when nothing matches)
      """
      matching = db.session.query(Conversation).filter(Conversation.id == conversation_id)
      return matching.first()
  def _get_message(self, message_id: str) -> Message:
      """
      Look up a message by its id.

      :param message_id: message id
      :return: the first matching row as returned by ``Query.first()``
          (which yields None when nothing matches)
      """
      matching = db.session.query(Message).filter(Message.id == message_id)
      return matching.first()