|
@@ -8,6 +8,8 @@ from typing import Union
|
|
|
|
|
|
from flask import Flask, current_app
|
|
|
from pydantic import ValidationError
|
|
|
+from sqlalchemy import select
|
|
|
+from sqlalchemy.orm import Session
|
|
|
|
|
|
import contexts
|
|
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
|
@@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
|
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
|
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
|
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
|
|
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
|
|
+from core.app.entities.app_invoke_entities import (
|
|
|
+ AdvancedChatAppGenerateEntity,
|
|
|
+ InvokeFrom,
|
|
|
+)
|
|
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
|
|
from core.file.message_file_parser import MessageFileParser
|
|
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
|
|
from core.ops.ops_trace_manager import TraceQueueManager
|
|
|
+from core.workflow.entities.variable_pool import VariablePool
|
|
|
+from core.workflow.enums import SystemVariable
|
|
|
from extensions.ext_database import db
|
|
|
from models.account import Account
|
|
|
from models.model import App, Conversation, EndUser, Message
|
|
|
-from models.workflow import Workflow
|
|
|
+from models.workflow import ConversationVariable, Workflow
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
conversation=conversation,
|
|
|
stream=stream
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
def single_iteration_generate(self, app_model: App,
|
|
|
workflow: Workflow,
|
|
|
node_id: str,
|
|
@@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
"""
|
|
|
if not node_id:
|
|
|
raise ValueError('node_id is required')
|
|
|
-
|
|
|
+
|
|
|
if args.get('inputs') is None:
|
|
|
raise ValueError('inputs is required')
|
|
|
-
|
|
|
+
|
|
|
extras = {
|
|
|
"auto_generate_conversation_name": False
|
|
|
}
|
|
@@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
# update conversation features
|
|
|
conversation.override_model_configs = workflow.features
|
|
|
db.session.commit()
|
|
|
- db.session.refresh(conversation)
|
|
|
+ # db.session.refresh(conversation)
|
|
|
|
|
|
# init queue manager
|
|
|
queue_manager = MessageBasedAppQueueManager(
|
|
@@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
message_id=message.id
|
|
|
)
|
|
|
|
|
|
+ # Init conversation variables
|
|
|
+ stmt = select(ConversationVariable).where(
|
|
|
+ ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
|
|
+ )
|
|
|
+ with Session(db.engine) as session:
|
|
|
+ conversation_variables = session.scalars(stmt).all()
|
|
|
+ if not conversation_variables:
|
|
|
+ # Create conversation variables if they don't exist.
|
|
|
+ conversation_variables = [
|
|
|
+ ConversationVariable.from_variable(
|
|
|
+ app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
|
|
+ )
|
|
|
+ for variable in workflow.conversation_variables
|
|
|
+ ]
|
|
|
+ session.add_all(conversation_variables)
|
|
|
+ # Convert database entities to variables.
|
|
|
+ conversation_variables = [item.to_variable() for item in conversation_variables]
|
|
|
+
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ # Increment dialogue count.
|
|
|
+ conversation.dialogue_count += 1
|
|
|
+
|
|
|
+ conversation_id = conversation.id
|
|
|
+ conversation_dialogue_count = conversation.dialogue_count
|
|
|
+ db.session.commit()
|
|
|
+ db.session.refresh(conversation)
|
|
|
+
|
|
|
+ inputs = application_generate_entity.inputs
|
|
|
+ query = application_generate_entity.query
|
|
|
+ files = application_generate_entity.files
|
|
|
+
|
|
|
+ user_id = None
|
|
|
+ if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
|
|
+ end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
|
|
+ if end_user:
|
|
|
+ user_id = end_user.session_id
|
|
|
+ else:
|
|
|
+ user_id = application_generate_entity.user_id
|
|
|
+
|
|
|
+ # Create a variable pool.
|
|
|
+ system_inputs = {
|
|
|
+ SystemVariable.QUERY: query,
|
|
|
+ SystemVariable.FILES: files,
|
|
|
+ SystemVariable.CONVERSATION_ID: conversation_id,
|
|
|
+ SystemVariable.USER_ID: user_id,
|
|
|
+ SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
|
|
|
+ }
|
|
|
+ variable_pool = VariablePool(
|
|
|
+ system_variables=system_inputs,
|
|
|
+ user_inputs=inputs,
|
|
|
+ environment_variables=workflow.environment_variables,
|
|
|
+ conversation_variables=conversation_variables,
|
|
|
+ )
|
|
|
+ contexts.workflow_variable_pool.set(variable_pool)
|
|
|
+
|
|
|
# new thread
|
|
|
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
|
|
'flask_app': current_app._get_current_object(),
|
|
|
'application_generate_entity': application_generate_entity,
|
|
|
'queue_manager': queue_manager,
|
|
|
- 'conversation_id': conversation.id,
|
|
|
'message_id': message.id,
|
|
|
- 'user': user,
|
|
|
- 'context': contextvars.copy_context()
|
|
|
+ 'context': contextvars.copy_context(),
|
|
|
})
|
|
|
|
|
|
worker_thread.start()
|
|
@@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
conversation=conversation,
|
|
|
message=message,
|
|
|
user=user,
|
|
|
- stream=stream
|
|
|
+ stream=stream,
|
|
|
)
|
|
|
|
|
|
return AdvancedChatAppGenerateResponseConverter.convert(
|
|
@@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
def _generate_worker(self, flask_app: Flask,
|
|
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
|
|
queue_manager: AppQueueManager,
|
|
|
- conversation_id: str,
|
|
|
message_id: str,
|
|
|
- user: Account,
|
|
|
context: contextvars.Context) -> None:
|
|
|
"""
|
|
|
Generate worker in a new thread.
|
|
@@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
user_id=application_generate_entity.user_id
|
|
|
)
|
|
|
else:
|
|
|
- # get conversation and message
|
|
|
- conversation = self._get_conversation(conversation_id)
|
|
|
+ # get message
|
|
|
message = self._get_message(message_id)
|
|
|
|
|
|
# chatbot app
|
|
@@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
runner.run(
|
|
|
application_generate_entity=application_generate_entity,
|
|
|
queue_manager=queue_manager,
|
|
|
- conversation=conversation,
|
|
|
message=message
|
|
|
)
|
|
|
except GenerateTaskStoppedException:
|
|
@@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
finally:
|
|
|
db.session.close()
|
|
|
|
|
|
- def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
|
|
- workflow: Workflow,
|
|
|
- queue_manager: AppQueueManager,
|
|
|
- conversation: Conversation,
|
|
|
- message: Message,
|
|
|
- user: Union[Account, EndUser],
|
|
|
- stream: bool = False) \
|
|
|
- -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
|
|
+ def _handle_advanced_chat_response(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ application_generate_entity: AdvancedChatAppGenerateEntity,
|
|
|
+ workflow: Workflow,
|
|
|
+ queue_manager: AppQueueManager,
|
|
|
+ conversation: Conversation,
|
|
|
+ message: Message,
|
|
|
+ user: Union[Account, EndUser],
|
|
|
+ stream: bool = False,
|
|
|
+ ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
|
|
"""
|
|
|
Handle response.
|
|
|
:param application_generate_entity: application generate entity
|
|
@@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|
|
conversation=conversation,
|
|
|
message=message,
|
|
|
user=user,
|
|
|
- stream=stream
|
|
|
+ stream=stream,
|
|
|
)
|
|
|
|
|
|
try:
|