|
@@ -0,0 +1,655 @@
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import threading
|
|
|
+import uuid
|
|
|
+from typing import cast, Optional, Any, Union, Generator, Tuple
|
|
|
+
|
|
|
+from flask import Flask, current_app
|
|
|
+from pydantic import ValidationError
|
|
|
+
|
|
|
+from core.app_runner.agent_app_runner import AgentApplicationRunner
|
|
|
+from core.app_runner.basic_app_runner import BasicApplicationRunner
|
|
|
+from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
|
|
|
+from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \
|
|
|
+ ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \
|
|
|
+ AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \
|
|
|
+ AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom
|
|
|
+from core.entities.model_entities import ModelStatus
|
|
|
+from core.file.file_obj import FileObj
|
|
|
+from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError
|
|
|
+from core.model_runtime.entities.message_entities import PromptMessageRole
|
|
|
+from core.model_runtime.entities.model_entities import ModelType
|
|
|
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
|
|
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
+from core.prompt.prompt_template import PromptTemplateParser
|
|
|
+from core.provider_manager import ProviderManager
|
|
|
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
|
|
|
+from extensions.ext_database import db
|
|
|
+from models.account import Account
|
|
|
+from models.model import EndUser, Conversation, Message, MessageFile, App
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class ApplicationManager:
|
|
|
+ """
|
|
|
+ This class is responsible for managing application
|
|
|
+ """
|
|
|
+
|
|
|
+ def generate(self, tenant_id: str,
|
|
|
+ app_id: str,
|
|
|
+ app_model_config_id: str,
|
|
|
+ app_model_config_dict: dict,
|
|
|
+ app_model_config_override: bool,
|
|
|
+ user: Union[Account, EndUser],
|
|
|
+ invoke_from: InvokeFrom,
|
|
|
+ inputs: dict[str, str],
|
|
|
+ query: Optional[str] = None,
|
|
|
+ files: Optional[list[FileObj]] = None,
|
|
|
+ conversation: Optional[Conversation] = None,
|
|
|
+ stream: bool = False,
|
|
|
+ extras: Optional[dict[str, Any]] = None) \
|
|
|
+ -> Union[dict, Generator]:
|
|
|
+ """
|
|
|
+ Generate App response.
|
|
|
+
|
|
|
+ :param tenant_id: workspace ID
|
|
|
+ :param app_id: app ID
|
|
|
+ :param app_model_config_id: app model config id
|
|
|
+ :param app_model_config_dict: app model config dict
|
|
|
+ :param app_model_config_override: app model config override
|
|
|
+ :param user: account or end user
|
|
|
+ :param invoke_from: invoke from source
|
|
|
+ :param inputs: inputs
|
|
|
+ :param query: query
|
|
|
+ :param files: file obj list
|
|
|
+ :param conversation: conversation
|
|
|
+ :param stream: is stream
|
|
|
+ :param extras: extras
|
|
|
+ """
|
|
|
+ # init task id
|
|
|
+ task_id = str(uuid.uuid4())
|
|
|
+
|
|
|
+ # init application generate entity
|
|
|
+ application_generate_entity = ApplicationGenerateEntity(
|
|
|
+ task_id=task_id,
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ app_id=app_id,
|
|
|
+ app_model_config_id=app_model_config_id,
|
|
|
+ app_model_config_dict=app_model_config_dict,
|
|
|
+ app_orchestration_config_entity=self._convert_from_app_model_config_dict(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ app_model_config_dict=app_model_config_dict
|
|
|
+ ),
|
|
|
+ app_model_config_override=app_model_config_override,
|
|
|
+ conversation_id=conversation.id if conversation else None,
|
|
|
+ inputs=conversation.inputs if conversation else inputs,
|
|
|
+ query=query.replace('\x00', '') if query else None,
|
|
|
+ files=files if files else [],
|
|
|
+ user_id=user.id,
|
|
|
+ stream=stream,
|
|
|
+ invoke_from=invoke_from,
|
|
|
+ extras=extras
|
|
|
+ )
|
|
|
+
|
|
|
+ # init generate records
|
|
|
+ (
|
|
|
+ conversation,
|
|
|
+ message
|
|
|
+ ) = self._init_generate_records(application_generate_entity)
|
|
|
+
|
|
|
+ # init queue manager
|
|
|
+ queue_manager = ApplicationQueueManager(
|
|
|
+ task_id=application_generate_entity.task_id,
|
|
|
+ user_id=application_generate_entity.user_id,
|
|
|
+ invoke_from=application_generate_entity.invoke_from,
|
|
|
+ conversation_id=conversation.id,
|
|
|
+ app_mode=conversation.mode,
|
|
|
+ message_id=message.id
|
|
|
+ )
|
|
|
+
|
|
|
+ # new thread
|
|
|
+ worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
|
|
+ 'flask_app': current_app._get_current_object(),
|
|
|
+ 'application_generate_entity': application_generate_entity,
|
|
|
+ 'queue_manager': queue_manager,
|
|
|
+ 'conversation_id': conversation.id,
|
|
|
+ 'message_id': message.id,
|
|
|
+ })
|
|
|
+
|
|
|
+ worker_thread.start()
|
|
|
+
|
|
|
+ # return response or stream generator
|
|
|
+ return self._handle_response(
|
|
|
+ application_generate_entity=application_generate_entity,
|
|
|
+ queue_manager=queue_manager,
|
|
|
+ conversation=conversation,
|
|
|
+ message=message,
|
|
|
+ stream=stream
|
|
|
+ )
|
|
|
+
|
|
|
+ def _generate_worker(self, flask_app: Flask,
|
|
|
+ application_generate_entity: ApplicationGenerateEntity,
|
|
|
+ queue_manager: ApplicationQueueManager,
|
|
|
+ conversation_id: str,
|
|
|
+ message_id: str) -> None:
|
|
|
+ """
|
|
|
+ Generate worker in a new thread.
|
|
|
+ :param flask_app: Flask app
|
|
|
+ :param application_generate_entity: application generate entity
|
|
|
+ :param queue_manager: queue manager
|
|
|
+ :param conversation_id: conversation ID
|
|
|
+ :param message_id: message ID
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ with flask_app.app_context():
|
|
|
+ try:
|
|
|
+ # get conversation and message
|
|
|
+ conversation = self._get_conversation(conversation_id)
|
|
|
+ message = self._get_message(message_id)
|
|
|
+
|
|
|
+ if application_generate_entity.app_orchestration_config_entity.agent:
|
|
|
+ # agent app
|
|
|
+ runner = AgentApplicationRunner()
|
|
|
+ runner.run(
|
|
|
+ application_generate_entity=application_generate_entity,
|
|
|
+ queue_manager=queue_manager,
|
|
|
+ conversation=conversation,
|
|
|
+ message=message
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # basic app
|
|
|
+ runner = BasicApplicationRunner()
|
|
|
+ runner.run(
|
|
|
+ application_generate_entity=application_generate_entity,
|
|
|
+ queue_manager=queue_manager,
|
|
|
+ conversation=conversation,
|
|
|
+ message=message
|
|
|
+ )
|
|
|
+ except ConversationTaskStoppedException:
|
|
|
+ pass
|
|
|
+ except InvokeAuthorizationError:
|
|
|
+ queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
|
|
|
+ except ValidationError as e:
|
|
|
+ logger.exception("Validation Error when generating")
|
|
|
+ queue_manager.publish_error(e)
|
|
|
+ except (ValueError, InvokeError) as e:
|
|
|
+ queue_manager.publish_error(e)
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception("Unknown Error when generating")
|
|
|
+ queue_manager.publish_error(e)
|
|
|
+ finally:
|
|
|
+ db.session.remove()
|
|
|
+
|
|
|
+ def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
|
|
|
+ queue_manager: ApplicationQueueManager,
|
|
|
+ conversation: Conversation,
|
|
|
+ message: Message,
|
|
|
+ stream: bool = False) -> Union[dict, Generator]:
|
|
|
+ """
|
|
|
+ Handle response.
|
|
|
+ :param application_generate_entity: application generate entity
|
|
|
+ :param queue_manager: queue manager
|
|
|
+ :param conversation: conversation
|
|
|
+ :param message: message
|
|
|
+ :param stream: is stream
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # init generate task pipeline
|
|
|
+ generate_task_pipeline = GenerateTaskPipeline(
|
|
|
+ application_generate_entity=application_generate_entity,
|
|
|
+ queue_manager=queue_manager,
|
|
|
+ conversation=conversation,
|
|
|
+ message=message
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ return generate_task_pipeline.process(stream=stream)
|
|
|
+ except ValueError as e:
|
|
|
+ if e.args[0] == "I/O operation on closed file.": # ignore this error
|
|
|
+ raise ConversationTaskStoppedException()
|
|
|
+ else:
|
|
|
+ logger.exception(e)
|
|
|
+ raise e
|
|
|
+ finally:
|
|
|
+ db.session.remove()
|
|
|
+
|
|
|
+ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
|
|
|
+ -> AppOrchestrationConfigEntity:
|
|
|
+ """
|
|
|
+ Convert app model config dict to entity.
|
|
|
+ :param tenant_id: tenant ID
|
|
|
+ :param app_model_config_dict: app model config dict
|
|
|
+ :raises ProviderTokenNotInitError: provider token not init error
|
|
|
+ :return: app orchestration config entity
|
|
|
+ """
|
|
|
+ properties = {}
|
|
|
+
|
|
|
+ copy_app_model_config_dict = app_model_config_dict.copy()
|
|
|
+
|
|
|
+ provider_manager = ProviderManager()
|
|
|
+ provider_model_bundle = provider_manager.get_provider_model_bundle(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ provider=copy_app_model_config_dict['model']['provider'],
|
|
|
+ model_type=ModelType.LLM
|
|
|
+ )
|
|
|
+
|
|
|
+ provider_name = provider_model_bundle.configuration.provider.provider
|
|
|
+ model_name = copy_app_model_config_dict['model']['name']
|
|
|
+
|
|
|
+ model_type_instance = provider_model_bundle.model_type_instance
|
|
|
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
|
|
+
|
|
|
+ # check model credentials
|
|
|
+ model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
|
|
+ model_type=ModelType.LLM,
|
|
|
+ model=copy_app_model_config_dict['model']['name']
|
|
|
+ )
|
|
|
+
|
|
|
+ if model_credentials is None:
|
|
|
+ raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
|
|
+
|
|
|
+ # check model
|
|
|
+ provider_model = provider_model_bundle.configuration.get_provider_model(
|
|
|
+ model=copy_app_model_config_dict['model']['name'],
|
|
|
+ model_type=ModelType.LLM
|
|
|
+ )
|
|
|
+
|
|
|
+ if provider_model is None:
|
|
|
+ model_name = copy_app_model_config_dict['model']['name']
|
|
|
+ raise ValueError(f"Model {model_name} not exist.")
|
|
|
+
|
|
|
+ if provider_model.status == ModelStatus.NO_CONFIGURE:
|
|
|
+ raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
|
|
+ elif provider_model.status == ModelStatus.NO_PERMISSION:
|
|
|
+ raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
|
|
+ elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
|
|
+ raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
|
|
+
|
|
|
+ # model config
|
|
|
+ completion_params = copy_app_model_config_dict['model'].get('completion_params')
|
|
|
+ stop = []
|
|
|
+ if 'stop' in completion_params:
|
|
|
+ stop = completion_params['stop']
|
|
|
+ del completion_params['stop']
|
|
|
+
|
|
|
+ # get model mode
|
|
|
+ model_mode = copy_app_model_config_dict['model'].get('mode')
|
|
|
+ if not model_mode:
|
|
|
+ mode_enum = model_type_instance.get_model_mode(
|
|
|
+ model=copy_app_model_config_dict['model']['name'],
|
|
|
+ credentials=model_credentials
|
|
|
+ )
|
|
|
+
|
|
|
+ model_mode = mode_enum.value
|
|
|
+
|
|
|
+ model_schema = model_type_instance.get_model_schema(
|
|
|
+ copy_app_model_config_dict['model']['name'],
|
|
|
+ model_credentials
|
|
|
+ )
|
|
|
+
|
|
|
+ if not model_schema:
|
|
|
+ raise ValueError(f"Model {model_name} not exist.")
|
|
|
+
|
|
|
+ properties['model_config'] = ModelConfigEntity(
|
|
|
+ provider=copy_app_model_config_dict['model']['provider'],
|
|
|
+ model=copy_app_model_config_dict['model']['name'],
|
|
|
+ model_schema=model_schema,
|
|
|
+ mode=model_mode,
|
|
|
+ provider_model_bundle=provider_model_bundle,
|
|
|
+ credentials=model_credentials,
|
|
|
+ parameters=completion_params,
|
|
|
+ stop=stop,
|
|
|
+ )
|
|
|
+
|
|
|
+ # prompt template
|
|
|
+ prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
|
|
|
+ if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
|
|
+ simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
|
|
|
+ properties['prompt_template'] = PromptTemplateEntity(
|
|
|
+ prompt_type=prompt_type,
|
|
|
+ simple_prompt_template=simple_prompt_template
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ advanced_chat_prompt_template = None
|
|
|
+ chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
|
|
|
+ if chat_prompt_config:
|
|
|
+ chat_prompt_messages = []
|
|
|
+ for message in chat_prompt_config.get("prompt", []):
|
|
|
+ chat_prompt_messages.append({
|
|
|
+ "text": message["text"],
|
|
|
+ "role": PromptMessageRole.value_of(message["role"])
|
|
|
+ })
|
|
|
+
|
|
|
+ advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
|
|
+ messages=chat_prompt_messages
|
|
|
+ )
|
|
|
+
|
|
|
+ advanced_completion_prompt_template = None
|
|
|
+ completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
|
|
|
+ if completion_prompt_config:
|
|
|
+ completion_prompt_template_params = {
|
|
|
+ 'prompt': completion_prompt_config['prompt']['text'],
|
|
|
+ }
|
|
|
+
|
|
|
+ if 'conversation_histories_role' in completion_prompt_config:
|
|
|
+ completion_prompt_template_params['role_prefix'] = {
|
|
|
+ 'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
|
|
+ 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
|
|
+ }
|
|
|
+
|
|
|
+ advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
|
|
+ **completion_prompt_template_params
|
|
|
+ )
|
|
|
+
|
|
|
+ properties['prompt_template'] = PromptTemplateEntity(
|
|
|
+ prompt_type=prompt_type,
|
|
|
+ advanced_chat_prompt_template=advanced_chat_prompt_template,
|
|
|
+ advanced_completion_prompt_template=advanced_completion_prompt_template
|
|
|
+ )
|
|
|
+
|
|
|
+ # external data variables
|
|
|
+ properties['external_data_variables'] = []
|
|
|
+ external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
|
|
|
+ for external_data_tool in external_data_tools:
|
|
|
+ if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
|
|
+ continue
|
|
|
+
|
|
|
+ properties['external_data_variables'].append(
|
|
|
+ ExternalDataVariableEntity(
|
|
|
+ variable=external_data_tool['variable'],
|
|
|
+ type=external_data_tool['type'],
|
|
|
+ config=external_data_tool['config']
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ # show retrieve source
|
|
|
+ show_retrieve_source = False
|
|
|
+ retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
|
|
|
+ if retriever_resource_dict:
|
|
|
+ if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
|
|
|
+ show_retrieve_source = True
|
|
|
+
|
|
|
+ properties['show_retrieve_source'] = show_retrieve_source
|
|
|
+
|
|
|
+ if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
|
|
|
+ and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
|
|
|
+ 'enabled']:
|
|
|
+ agent_dict = copy_app_model_config_dict.get('agent_mode')
|
|
|
+ if agent_dict['strategy'] in ['router', 'react_router']:
|
|
|
+ dataset_ids = []
|
|
|
+ for tool in agent_dict.get('tools', []):
|
|
|
+ key = list(tool.keys())[0]
|
|
|
+
|
|
|
+ if key != 'dataset':
|
|
|
+ continue
|
|
|
+
|
|
|
+ tool_item = tool[key]
|
|
|
+
|
|
|
+ if "enabled" not in tool_item or not tool_item["enabled"]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ dataset_id = tool_item['id']
|
|
|
+ dataset_ids.append(dataset_id)
|
|
|
+
|
|
|
+ dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
|
|
|
+ query_variable = copy_app_model_config_dict.get('dataset_query_variable')
|
|
|
+ if dataset_configs['retrieval_model'] == 'single':
|
|
|
+ properties['dataset'] = DatasetEntity(
|
|
|
+ dataset_ids=dataset_ids,
|
|
|
+ retrieve_config=DatasetRetrieveConfigEntity(
|
|
|
+ query_variable=query_variable,
|
|
|
+ retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
|
|
+ dataset_configs['retrieval_model']
|
|
|
+ ),
|
|
|
+ single_strategy=agent_dict['strategy']
|
|
|
+ )
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ properties['dataset'] = DatasetEntity(
|
|
|
+ dataset_ids=dataset_ids,
|
|
|
+ retrieve_config=DatasetRetrieveConfigEntity(
|
|
|
+ query_variable=query_variable,
|
|
|
+ retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
|
|
+ dataset_configs['retrieval_model']
|
|
|
+ ),
|
|
|
+ top_k=dataset_configs.get('top_k'),
|
|
|
+ score_threshold=dataset_configs.get('score_threshold'),
|
|
|
+ reranking_model=dataset_configs.get('reranking_model')
|
|
|
+ )
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ if agent_dict['strategy'] == 'react':
|
|
|
+ strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
|
|
+ else:
|
|
|
+ strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
|
|
+
|
|
|
+ agent_tools = []
|
|
|
+ for tool in agent_dict.get('tools', []):
|
|
|
+ key = list(tool.keys())[0]
|
|
|
+ tool_item = tool[key]
|
|
|
+
|
|
|
+ agent_tool_properties = {
|
|
|
+ "tool_id": key
|
|
|
+ }
|
|
|
+
|
|
|
+ if "enabled" not in tool_item or not tool_item["enabled"]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ agent_tool_properties["config"] = tool_item
|
|
|
+ agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
|
|
+
|
|
|
+ properties['agent'] = AgentEntity(
|
|
|
+ provider=properties['model_config'].provider,
|
|
|
+ model=properties['model_config'].model,
|
|
|
+ strategy=strategy,
|
|
|
+ tools=agent_tools
|
|
|
+ )
|
|
|
+
|
|
|
+ # file upload
|
|
|
+ file_upload_dict = copy_app_model_config_dict.get('file_upload')
|
|
|
+ if file_upload_dict:
|
|
|
+ if 'image' in file_upload_dict and file_upload_dict['image']:
|
|
|
+ if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
|
|
+ properties['file_upload'] = FileUploadEntity(
|
|
|
+ image_config={
|
|
|
+ 'number_limits': file_upload_dict['image']['number_limits'],
|
|
|
+ 'detail': file_upload_dict['image']['detail'],
|
|
|
+ 'transfer_methods': file_upload_dict['image']['transfer_methods']
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # opening statement
|
|
|
+ properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
|
|
|
+
|
|
|
+ # suggested questions after answer
|
|
|
+ suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
|
|
|
+ if suggested_questions_after_answer_dict:
|
|
|
+ if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
|
|
|
+ properties['suggested_questions_after_answer'] = True
|
|
|
+
|
|
|
+ # more like this
|
|
|
+ more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
|
|
|
+ if more_like_this_dict:
|
|
|
+ if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
|
|
|
+ properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
|
|
|
+
|
|
|
+ # speech to text
|
|
|
+ speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
|
|
|
+ if speech_to_text_dict:
|
|
|
+ if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
|
|
|
+ properties['speech_to_text'] = True
|
|
|
+
|
|
|
+ # sensitive word avoidance
|
|
|
+ sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
|
|
|
+ if sensitive_word_avoidance_dict:
|
|
|
+ if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
|
|
|
+ properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
|
|
|
+ type=sensitive_word_avoidance_dict.get('type'),
|
|
|
+ config=sensitive_word_avoidance_dict.get('config'),
|
|
|
+ )
|
|
|
+
|
|
|
+ return AppOrchestrationConfigEntity(**properties)
|
|
|
+
|
|
|
+ def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
|
|
|
+ -> Tuple[Conversation, Message]:
|
|
|
+ """
|
|
|
+ Initialize generate records
|
|
|
+ :param application_generate_entity: application generate entity
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
|
|
+
|
|
|
+ model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
|
|
|
+ model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
|
|
+ model_schema = model_type_instance.get_model_schema(
|
|
|
+ model=app_orchestration_config_entity.model_config.model,
|
|
|
+ credentials=app_orchestration_config_entity.model_config.credentials
|
|
|
+ )
|
|
|
+
|
|
|
+ app_record = (db.session.query(App)
|
|
|
+ .filter(App.id == application_generate_entity.app_id).first())
|
|
|
+
|
|
|
+ app_mode = app_record.mode
|
|
|
+
|
|
|
+ # get from source
|
|
|
+ end_user_id = None
|
|
|
+ account_id = None
|
|
|
+ if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
|
|
+ from_source = 'api'
|
|
|
+ end_user_id = application_generate_entity.user_id
|
|
|
+ else:
|
|
|
+ from_source = 'console'
|
|
|
+ account_id = application_generate_entity.user_id
|
|
|
+
|
|
|
+ override_model_configs = None
|
|
|
+ if application_generate_entity.app_model_config_override:
|
|
|
+ override_model_configs = application_generate_entity.app_model_config_dict
|
|
|
+
|
|
|
+ introduction = ''
|
|
|
+ if app_mode == 'chat':
|
|
|
+ # get conversation introduction
|
|
|
+ introduction = self._get_conversation_introduction(application_generate_entity)
|
|
|
+
|
|
|
+ if not application_generate_entity.conversation_id:
|
|
|
+ conversation = Conversation(
|
|
|
+ app_id=app_record.id,
|
|
|
+ app_model_config_id=application_generate_entity.app_model_config_id,
|
|
|
+ model_provider=app_orchestration_config_entity.model_config.provider,
|
|
|
+ model_id=app_orchestration_config_entity.model_config.model,
|
|
|
+ override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
|
|
+ mode=app_mode,
|
|
|
+ name='New conversation',
|
|
|
+ inputs=application_generate_entity.inputs,
|
|
|
+ introduction=introduction,
|
|
|
+ system_instruction="",
|
|
|
+ system_instruction_tokens=0,
|
|
|
+ status='normal',
|
|
|
+ from_source=from_source,
|
|
|
+ from_end_user_id=end_user_id,
|
|
|
+ from_account_id=account_id,
|
|
|
+ )
|
|
|
+
|
|
|
+ db.session.add(conversation)
|
|
|
+ db.session.commit()
|
|
|
+ else:
|
|
|
+ conversation = (
|
|
|
+ db.session.query(Conversation)
|
|
|
+ .filter(
|
|
|
+ Conversation.id == application_generate_entity.conversation_id,
|
|
|
+ Conversation.app_id == app_record.id
|
|
|
+ ).first()
|
|
|
+ )
|
|
|
+
|
|
|
+ currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
|
|
|
+
|
|
|
+ message = Message(
|
|
|
+ app_id=app_record.id,
|
|
|
+ model_provider=app_orchestration_config_entity.model_config.provider,
|
|
|
+ model_id=app_orchestration_config_entity.model_config.model,
|
|
|
+ override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
|
|
+ conversation_id=conversation.id,
|
|
|
+ inputs=application_generate_entity.inputs,
|
|
|
+ query=application_generate_entity.query or "",
|
|
|
+ message="",
|
|
|
+ message_tokens=0,
|
|
|
+ message_unit_price=0,
|
|
|
+ message_price_unit=0,
|
|
|
+ answer="",
|
|
|
+ answer_tokens=0,
|
|
|
+ answer_unit_price=0,
|
|
|
+ answer_price_unit=0,
|
|
|
+ provider_response_latency=0,
|
|
|
+ total_price=0,
|
|
|
+ currency=currency,
|
|
|
+ from_source=from_source,
|
|
|
+ from_end_user_id=end_user_id,
|
|
|
+ from_account_id=account_id,
|
|
|
+ agent_based=app_orchestration_config_entity.agent is not None
|
|
|
+ )
|
|
|
+
|
|
|
+ db.session.add(message)
|
|
|
+ db.session.commit()
|
|
|
+
|
|
|
+ for file in application_generate_entity.files:
|
|
|
+ message_file = MessageFile(
|
|
|
+ message_id=message.id,
|
|
|
+ type=file.type.value,
|
|
|
+ transfer_method=file.transfer_method.value,
|
|
|
+ url=file.url,
|
|
|
+ upload_file_id=file.upload_file_id,
|
|
|
+ created_by_role=('account' if account_id else 'end_user'),
|
|
|
+ created_by=account_id or end_user_id,
|
|
|
+ )
|
|
|
+ db.session.add(message_file)
|
|
|
+ db.session.commit()
|
|
|
+
|
|
|
+ return conversation, message
|
|
|
+
|
|
|
+ def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
|
|
|
+ """
|
|
|
+ Get conversation introduction
|
|
|
+ :param application_generate_entity: application generate entity
|
|
|
+ :return: conversation introduction
|
|
|
+ """
|
|
|
+ app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
|
|
|
+ introduction = app_orchestration_config_entity.opening_statement
|
|
|
+
|
|
|
+ if introduction:
|
|
|
+ try:
|
|
|
+ inputs = application_generate_entity.inputs
|
|
|
+ prompt_template = PromptTemplateParser(template=introduction)
|
|
|
+ prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
|
|
+ introduction = prompt_template.format(prompt_inputs)
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ return introduction
|
|
|
+
|
|
|
+ def _get_conversation(self, conversation_id: str) -> Conversation:
|
|
|
+ """
|
|
|
+ Get conversation by conversation id
|
|
|
+ :param conversation_id: conversation id
|
|
|
+ :return: conversation
|
|
|
+ """
|
|
|
+ conversation = (
|
|
|
+ db.session.query(Conversation)
|
|
|
+ .filter(Conversation.id == conversation_id)
|
|
|
+ .first()
|
|
|
+ )
|
|
|
+
|
|
|
+ return conversation
|
|
|
+
|
|
|
+ def _get_message(self, message_id: str) -> Message:
|
|
|
+ """
|
|
|
+ Get message by message id
|
|
|
+ :param message_id: message id
|
|
|
+ :return: message
|
|
|
+ """
|
|
|
+ message = (
|
|
|
+ db.session.query(Message)
|
|
|
+ .filter(Message.id == message_id)
|
|
|
+ .first()
|
|
|
+ )
|
|
|
+
|
|
|
+ return message
|