
refactor: optimize database usage (#12071)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 4 months ago
parent commit 83ea931e3c
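
This commit replaces the long-lived, shared `db.session` that the task pipelines used to refresh, commit, and close by hand with plain identifiers captured in `__init__` (`self._message_id`, `self._conversation_id`, `self._workflow_id`, ...) and short-lived `Session(db.engine)` context managers opened only around the statements that actually touch the database. A minimal before/after sketch of that pattern, reusing the `db` and `Message` imports that appear in the diff below (the helper names are hypothetical):

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.model import Message


    def update_message_metadata_before(message: Message, metadata_json: str | None) -> None:
        # Old shape: mutate an ORM object held for the pipeline's whole lifetime
        # on the request-scoped db.session, then commit/refresh/close manually.
        message.message_metadata = metadata_json
        db.session.commit()
        db.session.refresh(message)
        db.session.close()


    def update_message_metadata_after(message_id: str, metadata_json: str | None) -> None:
        # New shape: keep only the id, open a session where the write happens,
        # and let the context manager release the connection afterwards.
        with Session(db.engine) as session:
            message = session.scalar(select(Message).where(Message.id == message_id))
            if not message:
                raise ValueError(f"Message not found: {message_id}")
            message.message_metadata = metadata_json
            session.commit()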

+ 180 - 172
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping
 from threading import Thread
 from typing import Any, Optional, Union
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
     _task_state: WorkflowTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
     _conversation_name_generate_thread: Optional[Thread] = None
@@ -96,32 +97,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         stream: bool,
         dialogue_count: int,
     ) -> None:
-        """
-        Initialize AdvancedChatAppGenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param workflow: workflow
-        :param queue_manager: queue manager
-        :param conversation: conversation
-        :param message: message
-        :param user: user
-        :param stream: stream
-        :param dialogue_count: dialogue count
-        """
-        super().__init__(application_generate_entity, queue_manager, user, stream)
+        super().__init__(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            stream=stream,
+        )
 
-        if isinstance(self._user, EndUser):
-            user_id = self._user.session_id
+        if isinstance(user, EndUser):
+            self._user_id = user.session_id
+            self._created_by_role = CreatedByRole.END_USER
+        elif isinstance(user, Account):
+            self._user_id = user.id
+            self._created_by_role = CreatedByRole.ACCOUNT
         else:
-            user_id = self._user.id
+            raise NotImplementedError(f"User type not supported: {type(user)}")
+
+        self._workflow_id = workflow.id
+        self._workflow_features_dict = workflow.features_dict
+
+        self._conversation_id = conversation.id
+        self._conversation_mode = conversation.mode
+
+        self._message_id = message.id
+        self._message_created_at = int(message.created_at.timestamp())
 
-        self._workflow = workflow
-        self._conversation = conversation
-        self._message = message
         self._workflow_system_variables = {
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: user_id,
+            SystemVariableKey.USER_ID: self._user_id,
             SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
@@ -139,13 +143,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         Process generate task pipeline.
         :return:
         """
-        db.session.refresh(self._workflow)
-        db.session.refresh(self._user)
-        db.session.close()
-
         # start generate conversation name thread
         self._conversation_name_generate_thread = self._generate_conversation_name(
-            self._conversation, self._application_generate_entity.query
+            conversation_id=self._conversation_id, query=self._application_generate_entity.query
         )
 
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -171,12 +171,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 return ChatbotAppBlockingResponse(
                     task_id=stream_response.task_id,
                     data=ChatbotAppBlockingResponse.Data(
-                        id=self._message.id,
-                        mode=self._conversation.mode,
-                        conversation_id=self._conversation.id,
-                        message_id=self._message.id,
+                        id=self._message_id,
+                        mode=self._conversation_mode,
+                        conversation_id=self._conversation_id,
+                        message_id=self._message_id,
                         answer=self._task_state.answer,
-                        created_at=int(self._message.created_at.timestamp()),
+                        created_at=self._message_created_at,
                         **extras,
                     ),
                 )
@@ -194,9 +194,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         for stream_response in generator:
             yield ChatbotAppStreamResponse(
-                conversation_id=self._conversation.id,
-                message_id=self._message.id,
-                created_at=int(self._message.created_at.timestamp()),
+                conversation_id=self._conversation_id,
+                message_id=self._message_id,
+                created_at=self._message_created_at,
                 stream_response=stream_response,
             )
 
@@ -214,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
-        features_dict = self._workflow.features_dict
+        features_dict = self._workflow_features_dict
 
         if (
             features_dict.get("text_to_speech")
@@ -274,26 +274,33 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             if isinstance(event, QueuePingEvent):
                 yield self._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event, self._message)
+                with Session(db.engine) as session:
+                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                    session.commit()
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                # init workflow run
-                workflow_run = self._handle_workflow_run_start()
-
-                self._refetch_message()
-                self._message.workflow_run_id = workflow_run.id
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
-
-                yield self._workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    # init workflow run
+                    workflow_run = self._handle_workflow_run_start(
+                        session=session,
+                        workflow_id=self._workflow_id,
+                        user_id=self._user_id,
+                        created_by_role=self._created_by_role,
+                    )
+                    message = self._get_message(session=session)
+                    if not message:
+                        raise ValueError(f"Message not found: {self._message_id}")
+                    message.workflow_run_id = workflow_run.id
+                    session.commit()
+
+                    workflow_start_resp = self._workflow_start_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                yield workflow_start_resp
             elif isinstance(
                 event,
                 QueueNodeRetryEvent,
@@ -304,28 +311,28 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_run=workflow_run, event=event
                 )
 
-                response = self._workflow_node_retry_to_stream_response(
+                node_retry_resp = self._workflow_node_retry_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
 
-                if response:
-                    yield response
+                if node_retry_resp:
+                    yield node_retry_resp
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
                     raise ValueError("workflow run not initialized.")
 
                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
 
-                response_start = self._workflow_node_start_to_stream_response(
+                node_start_resp = self._workflow_node_start_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
 
-                if response_start:
-                    yield response_start
+                if node_start_resp:
+                    yield node_start_resp
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)
 
@@ -333,25 +340,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
 
-                response_finish = self._workflow_node_finish_to_stream_response(
+                node_finish_resp = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
 
-                if response_finish:
-                    yield response_finish
+                if node_finish_resp:
+                    yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
-                response_finish = self._workflow_node_finish_to_stream_response(
+                node_finish_resp = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )
-
-                if response:
-                    yield response
+                if node_finish_resp:
+                    yield node_finish_resp
 
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
@@ -395,20 +401,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("workflow run not initialized.")
 
-                workflow_run = self._handle_workflow_run_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    conversation_id=self._conversation.id,
-                    trace_manager=trace_manager,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        conversation_id=self._conversation_id,
+                        trace_manager=trace_manager,
+                    )
 
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
 
+                yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
@@ -417,21 +427,25 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                workflow_run = self._handle_workflow_run_partial_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    exceptions_count=event.exceptions_count,
-                    conversation_id=None,
-                    trace_manager=trace_manager,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_partial_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        exceptions_count=event.exceptions_count,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )
 
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
 
+                yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
@@ -440,71 +454,73 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                workflow_run = self._handle_workflow_run_failed(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.FAILED,
-                    error=event.error,
-                    conversation_id=self._conversation.id,
-                    trace_manager=trace_manager,
-                    exceptions_count=event.exceptions_count,
-                )
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
-
-                err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
-                break
-            elif isinstance(event, QueueStopEvent):
-                if workflow_run and graph_runtime_state:
+                with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_failed(
+                        session=session,
                         workflow_run=workflow_run,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
-                        status=WorkflowRunStatus.STOPPED,
-                        error=event.get_stop_reason(),
-                        conversation_id=self._conversation.id,
+                        status=WorkflowRunStatus.FAILED,
+                        error=event.error,
+                        conversation_id=self._conversation_id,
                         trace_manager=trace_manager,
+                        exceptions_count=event.exceptions_count,
                     )
-
-                    yield self._workflow_finish_to_stream_response(
-                        task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
-
-                # Save message
-                self._save_message(graph_runtime_state=graph_runtime_state)
+                    err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
+                    err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                    session.commit()
+                yield workflow_finish_resp
+                yield self._error_to_stream_response(err)
+                break
+            elif isinstance(event, QueueStopEvent):
+                if workflow_run and graph_runtime_state:
+                    with Session(db.engine) as session:
+                        workflow_run = self._handle_workflow_run_failed(
+                            session=session,
+                            workflow_run=workflow_run,
+                            start_at=graph_runtime_state.start_at,
+                            total_tokens=graph_runtime_state.total_tokens,
+                            total_steps=graph_runtime_state.node_run_steps,
+                            status=WorkflowRunStatus.STOPPED,
+                            error=event.get_stop_reason(),
+                            conversation_id=self._conversation_id,
+                            trace_manager=trace_manager,
+                        )
+
+                        workflow_finish_resp = self._workflow_finish_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                        )
+                        # Save message
+                        self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+                        session.commit()
+                    yield workflow_finish_resp
 
                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._handle_retriever_resources(event)
 
-                self._refetch_message()
-
-                self._message.message_metadata = (
-                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
-                )
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
+                with Session(db.engine) as session:
+                    message = self._get_message(session=session)
+                    message.message_metadata = (
+                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                    )
+                    session.commit()
             elif isinstance(event, QueueAnnotationReplyEvent):
                 self._handle_annotation_reply(event)
 
-                self._refetch_message()
-
-                self._message.message_metadata = (
-                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
-                )
-
-                db.session.commit()
-                db.session.refresh(self._message)
-                db.session.close()
+                with Session(db.engine) as session:
+                    message = self._get_message(session=session)
+                    message.message_metadata = (
+                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+                    )
+                    session.commit()
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
@@ -521,7 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(
-                    answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
+                    answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
@@ -536,7 +552,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     yield self._message_replace_to_stream_response(answer=output_moderation_answer)
 
                 # Save message
-                self._save_message(graph_runtime_state=graph_runtime_state)
+                with Session(db.engine) as session:
+                    self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+                    session.commit()
 
                 yield self._message_end_to_stream_response()
             else:
@@ -549,54 +567,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
-        self._refetch_message()
-
-        self._message.answer = self._task_state.answer
-        self._message.provider_response_latency = time.perf_counter() - self._start_at
-        self._message.message_metadata = (
+    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
+        message = self._get_message(session=session)
+        message.answer = self._task_state.answer
+        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
         message_files = [
             MessageFile(
-                message_id=self._message.id,
+                message_id=message.id,
                 type=file["type"],
                 transfer_method=file["transfer_method"],
                 url=file["remote_url"],
                 belongs_to="assistant",
                 upload_file_id=file["related_id"],
                 created_by_role=CreatedByRole.ACCOUNT
-                if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                 else CreatedByRole.END_USER,
-                created_by=self._message.from_account_id or self._message.from_end_user_id or "",
+                created_by=message.from_account_id or message.from_end_user_id or "",
             )
             for file in self._recorded_files
         ]
-        db.session.add_all(message_files)
+        session.add_all(message_files)
 
         if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage
-            self._message.message_tokens = usage.prompt_tokens
-            self._message.message_unit_price = usage.prompt_unit_price
-            self._message.message_price_unit = usage.prompt_price_unit
-            self._message.answer_tokens = usage.completion_tokens
-            self._message.answer_unit_price = usage.completion_unit_price
-            self._message.answer_price_unit = usage.completion_price_unit
-            self._message.total_price = usage.total_price
-            self._message.currency = usage.currency
-
+            message.message_tokens = usage.prompt_tokens
+            message.message_unit_price = usage.prompt_unit_price
+            message.message_price_unit = usage.prompt_price_unit
+            message.answer_tokens = usage.completion_tokens
+            message.answer_unit_price = usage.completion_unit_price
+            message.answer_price_unit = usage.completion_price_unit
+            message.total_price = usage.total_price
+            message.currency = usage.currency
             self._task_state.metadata["usage"] = jsonable_encoder(usage)
         else:
             self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
-
-        db.session.commit()
-
         message_was_created.send(
-            self._message,
+            message,
             application_generate_entity=self._application_generate_entity,
-            conversation=self._conversation,
-            is_first_message=self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras,
         )
 
     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -613,7 +623,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
+            id=self._message_id,
             files=self._recorded_files,
             metadata=extras.get("metadata", {}),
         )
@@ -641,11 +651,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         return False
 
-    def _refetch_message(self) -> None:
-        """
-        Refetch message.
-        :return:
-        """
-        message = db.session.query(Message).filter(Message.id == self._message.id).first()
-        if message:
-            self._message = message
+    def _get_message(self, *, session: Session):
+        stmt = select(Message).where(Message.id == self._message_id)
+        message = session.scalar(stmt)
+        if not message:
+            raise ValueError(f"Message not found: {self._message_id}")
+        return message
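
A recurring detail in the rewritten event loop above: stream responses are built and committed while the session is open, but only yielded after the `with` block exits, so no connection stays pinned while the generator sits suspended waiting for its consumer. A hypothetical helper showing that shape, under the same `db` and `Message` assumptions as before:

    from collections.abc import Generator

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.model import Message


    def stream_answers(message_ids: list[str]) -> Generator[str, None, None]:
        for message_id in message_ids:
            with Session(db.engine) as session:
                # All database work happens inside the context manager; only a
                # plain value escapes it.
                message = session.scalar(select(Message).where(Message.id == message_id))
                answer = message.answer if message else ""
            # The session is already closed by the time the generator suspends here.
            yield answer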

+ 0 - 1
api/core/app/apps/message_based_app_generator.py

@@ -70,7 +70,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             queue_manager=queue_manager,
             conversation=conversation,
             message=message,
-            user=user,
             stream=stream,
         )
 

+ 106 - 86
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -3,6 +3,8 @@ import time
 from collections.abc import Generator
 from typing import Any, Optional, Union
 
+from sqlalchemy.orm import Session
+
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -50,6 +52,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
+from models.enums import CreatedByRole
 from models.model import EndUser
 from models.workflow import (
     Workflow,
@@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
@@ -83,25 +84,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
-        """
-        Initialize GenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param workflow: workflow
-        :param queue_manager: queue manager
-        :param user: user
-        :param stream: is streamed
-        """
-        super().__init__(application_generate_entity, queue_manager, user, stream)
+        super().__init__(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            stream=stream,
+        )
 
-        if isinstance(self._user, EndUser):
-            user_id = self._user.session_id
+        if isinstance(user, EndUser):
+            self._user_id = user.session_id
+            self._created_by_role = CreatedByRole.END_USER
+        elif isinstance(user, Account):
+            self._user_id = user.id
+            self._created_by_role = CreatedByRole.ACCOUNT
         else:
-            user_id = self._user.id
+            raise ValueError(f"Invalid user type: {type(user)}")
+
+        self._workflow_id = workflow.id
+        self._workflow_features_dict = workflow.features_dict
 
-        self._workflow = workflow
         self._workflow_system_variables = {
             SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.USER_ID: user_id,
+            SystemVariableKey.USER_ID: self._user_id,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
             SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@@ -115,10 +118,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         Process generate task pipeline.
         :return:
         """
-        db.session.refresh(self._workflow)
-        db.session.refresh(self._user)
-        db.session.close()
-
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
         if self._stream:
             return self._to_stream_response(generator)
@@ -185,7 +184,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
-        features_dict = self._workflow.features_dict
+        features_dict = self._workflow_features_dict
 
         if (
             features_dict.get("text_to_speech")
@@ -242,18 +241,26 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             if isinstance(event, QueuePingEvent):
                 yield self._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event)
+                err = self._handle_error(event=event)
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                # init workflow run
-                workflow_run = self._handle_workflow_run_start()
-                yield self._workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    # init workflow run
+                    workflow_run = self._handle_workflow_run_start(
+                        session=session,
+                        workflow_id=self._workflow_id,
+                        user_id=self._user_id,
+                        created_by_role=self._created_by_role,
+                    )
+                    start_resp = self._workflow_start_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+                yield start_resp
             elif isinstance(
                 event,
                 QueueNodeRetryEvent,
@@ -350,22 +357,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                workflow_run = self._handle_workflow_run_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    conversation_id=None,
-                    trace_manager=trace_manager,
-                )
-
-                # save workflow app log
-                self._save_workflow_app_log(workflow_run)
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                    )
+                    session.commit()
+                yield workflow_finish_resp
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
                     raise ValueError("workflow run not initialized.")
@@ -373,49 +386,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                workflow_run = self._handle_workflow_run_partial_success(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    outputs=event.outputs,
-                    exceptions_count=event.exceptions_count,
-                    conversation_id=None,
-                    trace_manager=trace_manager,
-                )
-
-                # save workflow app log
-                self._save_workflow_app_log(workflow_run)
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_partial_success(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        outputs=event.outputs,
+                        exceptions_count=event.exceptions_count,
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+
+                yield workflow_finish_resp
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                 if not workflow_run:
                     raise ValueError("workflow run not initialized.")
 
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
-                workflow_run = self._handle_workflow_run_failed(
-                    workflow_run=workflow_run,
-                    start_at=graph_runtime_state.start_at,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.FAILED
-                    if isinstance(event, QueueWorkflowFailedEvent)
-                    else WorkflowRunStatus.STOPPED,
-                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
-                    conversation_id=None,
-                    trace_manager=trace_manager,
-                    exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
-                )
-
-                # save workflow app log
-                self._save_workflow_app_log(workflow_run)
-
-                yield self._workflow_finish_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._handle_workflow_run_failed(
+                        session=session,
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        status=WorkflowRunStatus.FAILED
+                        if isinstance(event, QueueWorkflowFailedEvent)
+                        else WorkflowRunStatus.STOPPED,
+                        error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                        conversation_id=None,
+                        trace_manager=trace_manager,
+                        exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+                    )
+
+                    # save workflow app log
+                    self._save_workflow_app_log(session=session, workflow_run=workflow_run)
+
+                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                    )
+                    session.commit()
+                yield workflow_finish_resp
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
@@ -435,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         if tts_publisher:
             tts_publisher.publish(None)
 
-    def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
+    def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
         """
         Save workflow app log.
         :return:
@@ -457,12 +479,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         workflow_app_log.workflow_id = workflow_run.workflow_id
         workflow_app_log.workflow_run_id = workflow_run.id
         workflow_app_log.created_from = created_from.value
-        workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
-        workflow_app_log.created_by = self._user.id
+        workflow_app_log.created_by_role = self._created_by_role
+        workflow_app_log.created_by = self._user_id
 
-        db.session.add(workflow_app_log)
-        db.session.commit()
-        db.session.close()
+        session.add(workflow_app_log)
 
     def _text_chunk_to_stream_response(
         self, text: str, from_variable_selector: Optional[list[str]] = None
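
The `_save_workflow_app_log` change above is representative of the helpers in this commit: they now receive the caller's session, only `add()` their rows, and leave the commit to the event loop that opened the session, so the workflow-run update, the app log, and the stream response land in one transaction. A rough sketch of that division of responsibility, trimmed to the fields visible in the diff (the real helper also fills tenant, app, and created_from):

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.enums import CreatedByRole
    from models.workflow import WorkflowAppLog, WorkflowRun


    def save_app_log(*, session: Session, workflow_run: WorkflowRun,
                     user_id: str, created_by_role: CreatedByRole) -> None:
        # The helper only stages the row; it never commits or closes anything.
        workflow_app_log = WorkflowAppLog()
        workflow_app_log.workflow_id = workflow_run.workflow_id
        workflow_app_log.workflow_run_id = workflow_run.id
        workflow_app_log.created_by_role = created_by_role
        workflow_app_log.created_by = user_id
        session.add(workflow_app_log)


    def on_workflow_finished(workflow_run: WorkflowRun, user_id: str,
                             created_by_role: CreatedByRole) -> None:
        # The caller owns the session, so everything written inside this block
        # is committed (or rolled back) together.
        with Session(db.engine) as session:
            save_app_log(session=session, workflow_run=workflow_run,
                         user_id=user_id, created_by_role=created_by_role)
            session.commit()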

+ 15 - 21
api/core/app/task_pipeline/based_generate_task_pipeline.py

@@ -1,6 +1,9 @@
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import (
@@ -17,9 +20,7 @@ from core.app.entities.task_entities import (
 from core.errors.error import QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.moderation.output_moderation import ModerationRule, OutputModeration
-from extensions.ext_database import db
-from models.account import Account
-from models.model import EndUser, Message
+from models.model import Message
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline:
         self,
         application_generate_entity: AppGenerateEntity,
         queue_manager: AppQueueManager,
-        user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
         """
@@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline:
         """
         self._application_generate_entity = application_generate_entity
         self._queue_manager = queue_manager
-        self._user = user
         self._start_at = time.perf_counter()
         self._output_moderation_handler = self._init_output_moderation()
         self._stream = stream
 
-    def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
-        """
-        Handle error event.
-        :param event: event
-        :param message: message
-        :return:
-        """
+    def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
         logger.debug("error: %s", event.error)
         e = event.error
         err: Exception
@@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline:
         else:
             err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
 
-        if message:
-            refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
-
-            if refetch_message:
-                err_desc = self._error_to_desc(err)
-                refetch_message.status = "error"
-                refetch_message.error = err_desc
+        if not message_id or not session:
+            return err
 
-                db.session.commit()
+        stmt = select(Message).where(Message.id == message_id)
+        message = session.scalar(stmt)
+        if not message:
+            return err
 
+        err_desc = self._error_to_desc(err)
+        message.status = "error"
+        message.error = err_desc
         return err
 
     def _error_to_desc(self, e: Exception) -> str:

+ 70 - 75
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py

@@ -5,6 +5,9 @@ from collections.abc import Generator
 from threading import Thread
 from typing import Optional, Union, cast
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created
 from extensions.ext_database import db
-from models.account import Account
-from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
+from models.model import AppMode, Conversation, Message, MessageAgentThought
 
 logger = logging.getLogger(__name__)
 
@@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,
-        user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
-        """
-        Initialize GenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param queue_manager: queue manager
-        :param conversation: conversation
-        :param message: message
-        :param user: user
-        :param stream: stream
-        """
-        super().__init__(application_generate_entity, queue_manager, user, stream)
+        super().__init__(
+            application_generate_entity=application_generate_entity,
+            queue_manager=queue_manager,
+            stream=stream,
+        )
         self._model_config = application_generate_entity.model_conf
         self._app_config = application_generate_entity.app_config
-        self._conversation = conversation
-        self._message = message
+
+        self._conversation_id = conversation.id
+        self._conversation_mode = conversation.mode
+
+        self._message_id = message.id
+        self._message_created_at = int(message.created_at.timestamp())
 
         self._task_state = EasyUITaskState(
             llm_result=LLMResult(
@@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         CompletionAppBlockingResponse,
         Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
     ]:
-        """
-        Process generate task pipeline.
-        :return:
-        """
-        db.session.refresh(self._conversation)
-        db.session.refresh(self._message)
-        db.session.close()
-
         if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
             # start generate conversation name thread
             self._conversation_name_generate_thread = self._generate_conversation_name(
-                self._conversation, self._application_generate_entity.query or ""
+                conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
             )
 
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 if self._task_state.metadata:
                     extras["metadata"] = self._task_state.metadata
                 response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
-                if self._conversation.mode == AppMode.COMPLETION.value:
+                if self._conversation_mode == AppMode.COMPLETION.value:
                     response = CompletionAppBlockingResponse(
                         task_id=self._application_generate_entity.task_id,
                         data=CompletionAppBlockingResponse.Data(
-                            id=self._message.id,
-                            mode=self._conversation.mode,
-                            message_id=self._message.id,
+                            id=self._message_id,
+                            mode=self._conversation_mode,
+                            message_id=self._message_id,
                             answer=cast(str, self._task_state.llm_result.message.content),
-                            created_at=int(self._message.created_at.timestamp()),
+                            created_at=self._message_created_at,
                             **extras,
                         ),
                     )
@@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     response = ChatbotAppBlockingResponse(
                         task_id=self._application_generate_entity.task_id,
                         data=ChatbotAppBlockingResponse.Data(
-                            id=self._message.id,
-                            mode=self._conversation.mode,
-                            conversation_id=self._conversation.id,
-                            message_id=self._message.id,
+                            id=self._message_id,
+                            mode=self._conversation_mode,
+                            conversation_id=self._conversation_id,
+                            message_id=self._message_id,
                             answer=cast(str, self._task_state.llm_result.message.content),
-                            created_at=int(self._message.created_at.timestamp()),
+                            created_at=self._message_created_at,
                             **extras,
                         ),
                     )
@@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         for stream_response in generator:
             if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
                 yield CompletionAppStreamResponse(
-                    message_id=self._message.id,
-                    created_at=int(self._message.created_at.timestamp()),
+                    message_id=self._message_id,
+                    created_at=self._message_created_at,
                     stream_response=stream_response,
                 )
             else:
                 yield ChatbotAppStreamResponse(
-                    conversation_id=self._conversation.id,
-                    message_id=self._message.id,
-                    created_at=int(self._message.created_at.timestamp()),
+                    conversation_id=self._conversation_id,
+                    message_id=self._message_id,
+                    created_at=self._message_created_at,
                     stream_response=stream_response,
                 )
 
@@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             event = message.event
 
             if isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event, self._message)
+                with Session(db.engine) as session:
+                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                    session.commit()
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     self._task_state.llm_result.message.content = output_moderation_answer
                     yield self._message_replace_to_stream_response(answer=output_moderation_answer)
 
-                # Save message
-                self._save_message(trace_manager)
-
-                yield self._message_end_to_stream_response()
+                with Session(db.engine) as session:
+                    # Save message
+                    self._save_message(session=session, trace_manager=trace_manager)
+                    session.commit()
+                message_end_resp = self._message_end_to_stream_response()
+                yield message_end_resp
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._handle_retriever_resources(event)
             elif isinstance(event, QueueAnnotationReplyEvent):
@@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 self._task_state.llm_result.message.content = current_content
 
                 if isinstance(event, QueueLLMChunkEvent):
-                    yield self._message_to_stream_response(cast(str, delta_text), self._message.id)
+                    yield self._message_to_stream_response(
+                        answer=cast(str, delta_text),
+                        message_id=self._message_id,
+                    )
                 else:
-                    yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id)
+                    yield self._agent_message_to_stream_response(
+                        answer=cast(str, delta_text),
+                        message_id=self._message_id,
+                    )
             elif isinstance(event, QueueMessageReplaceEvent):
                 yield self._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueuePingEvent):
@@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
+    def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
         """
         Save message.
         :return:
@@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         llm_result = self._task_state.llm_result
         usage = llm_result.usage
 
-        message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        message_stmt = select(Message).where(Message.id == self._message_id)
+        message = session.scalar(message_stmt)
         if not message:
-            raise Exception(f"Message {self._message.id} not found")
-        self._message = message
-        conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
+            raise ValueError(f"message {self._message_id} not found")
+        conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
+        conversation = session.scalar(conversation_stmt)
         if not conversation:
-            raise Exception(f"Conversation {self._conversation.id} not found")
-        self._conversation = conversation
+            raise ValueError(f"Conversation {self._conversation_id} not found")
 
-        self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+        message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
             self._model_config.mode, self._task_state.llm_result.prompt_messages
         )
-        self._message.message_tokens = usage.prompt_tokens
-        self._message.message_unit_price = usage.prompt_unit_price
-        self._message.message_price_unit = usage.prompt_price_unit
-        self._message.answer = (
+        message.message_tokens = usage.prompt_tokens
+        message.message_unit_price = usage.prompt_unit_price
+        message.message_price_unit = usage.prompt_price_unit
+        message.answer = (
             PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
             if llm_result.message.content
             else ""
         )
-        self._message.answer_tokens = usage.completion_tokens
-        self._message.answer_unit_price = usage.completion_unit_price
-        self._message.answer_price_unit = usage.completion_price_unit
-        self._message.provider_response_latency = time.perf_counter() - self._start_at
-        self._message.total_price = usage.total_price
-        self._message.currency = usage.currency
-        self._message.message_metadata = (
+        message.answer_tokens = usage.completion_tokens
+        message.answer_unit_price = usage.completion_unit_price
+        message.answer_price_unit = usage.completion_price_unit
+        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.total_price = usage.total_price
+        message.currency = usage.currency
+        message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
 
-        db.session.commit()
-
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
-                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
+                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
                 )
             )
 
         message_was_created.send(
-            self._message,
+            message,
             application_generate_entity=self._application_generate_entity,
-            conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
-            and hasattr(self._application_generate_entity, "conversation_id")
-            and self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras,
         )
 
     def _handle_stop(self, event: QueueStopEvent) -> None:
@@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
 
         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
+            id=self._message_id,
             metadata=extras.get("metadata", {}),
         )
 

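Taken together, the pipeline hunks above converge on a single shape: the caller opens one Session, helper methods accept it and only load and mutate ORM objects, and the commit happens exactly once at the call site. A minimal, self-contained sketch of that shape, using a toy Message model and an in-memory SQLite engine as stand-ins for the repository's models and db.engine:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Message(Base):
    __tablename__ = "messages"
    id: Mapped[str] = mapped_column(primary_key=True)
    answer: Mapped[str] = mapped_column(default="")

def save_message(*, session: Session, message_id: str, answer: str) -> None:
    # The helper never commits; it loads and mutates rows inside the caller's session.
    message = session.scalar(select(Message).where(Message.id == message_id))
    if not message:
        raise ValueError(f"Message {message_id} not found")
    message.answer = answer

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Message(id="m1"))
    session.commit()
with Session(engine) as session:
    save_message(session=session, message_id="m1", answer="hello")
    session.commit()  # exactly one commit, at the call site

This is the same division of responsibility the with Session(db.engine) as session: blocks above rely on: _save_message only needs the session handle, and the caller decides when to commit.
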
+ 2 - 2
api/core/app/task_pipeline/message_cycle_manage.py

@@ -36,7 +36,7 @@ class MessageCycleManage:
     ]
     _task_state: Union[EasyUITaskState, WorkflowTaskState]
 
-    def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
+    def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
         """
         Generate conversation name.
         :param conversation: conversation
@@ -56,7 +56,7 @@ class MessageCycleManage:
                 target=self._generate_conversation_name_worker,
                 kwargs={
                     "flask_app": current_app._get_current_object(),  # type: ignore
-                    "conversation_id": conversation.id,
+                    "conversation_id": conversation_id,
                     "query": query,
                 },
             )

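Handing the naming thread a plain conversation_id instead of a Conversation instance keeps ORM objects from crossing a thread (and session) boundary; the worker re-loads the row in a session it owns. A hedged sketch of that hand-off, again with a toy ToyConversation model and a local engine rather than the repository's Conversation and db.engine:

from threading import Thread

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ToyConversation(Base):
    __tablename__ = "toy_conversations"
    id: Mapped[str] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(default="")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(ToyConversation(id="c1"))
    session.commit()

def name_worker(conversation_id: str, query: str) -> None:
    # The worker owns its session; only primitives crossed the thread boundary.
    with Session(engine) as session:
        conversation = session.scalar(select(ToyConversation).where(ToyConversation.id == conversation_id))
        if conversation:
            conversation.name = query[:40]
            session.commit()

thread = Thread(target=name_worker, kwargs={"conversation_id": "c1", "query": "name this chat"})
thread.start()
thread.join()
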
+ 85 - 97
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -5,6 +5,7 @@ from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 from uuid import uuid4
 
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
@@ -63,27 +64,34 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
 
 class WorkflowCycleManage:
     _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
 
-    def _handle_workflow_run_start(self) -> WorkflowRun:
-        max_sequence = (
-            db.session.query(db.func.max(WorkflowRun.sequence_number))
-            .filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
-            .filter(WorkflowRun.app_id == self._workflow.app_id)
-            .scalar()
-            or 0
+    def _handle_workflow_run_start(
+        self,
+        *,
+        session: Session,
+        workflow_id: str,
+        user_id: str,
+        created_by_role: CreatedByRole,
+    ) -> WorkflowRun:
+        workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
+        workflow = session.scalar(workflow_stmt)
+        if not workflow:
+            raise ValueError(f"Workflow not found: {workflow_id}")
+
+        max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
+            WorkflowRun.tenant_id == workflow.tenant_id,
+            WorkflowRun.app_id == workflow.app_id,
         )
+        max_sequence = session.scalar(max_sequence_stmt) or 0
         new_sequence_number = max_sequence + 1
 
         inputs = {**self._application_generate_entity.inputs}
         for key, value in (self._workflow_system_variables or {}).items():
             if key.value == "conversation":
                 continue
-
             inputs[f"sys.{key.value}"] = value
 
         triggered_from = (
@@ -96,33 +104,32 @@ class WorkflowCycleManage:
         inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
 
         # init workflow run
-        with Session(db.engine, expire_on_commit=False) as session:
-            workflow_run = WorkflowRun()
-            system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
-            workflow_run.id = system_id or str(uuid4())
-            workflow_run.tenant_id = self._workflow.tenant_id
-            workflow_run.app_id = self._workflow.app_id
-            workflow_run.sequence_number = new_sequence_number
-            workflow_run.workflow_id = self._workflow.id
-            workflow_run.type = self._workflow.type
-            workflow_run.triggered_from = triggered_from.value
-            workflow_run.version = self._workflow.version
-            workflow_run.graph = self._workflow.graph
-            workflow_run.inputs = json.dumps(inputs)
-            workflow_run.status = WorkflowRunStatus.RUNNING
-            workflow_run.created_by_role = (
-                CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
-            )
-            workflow_run.created_by = self._user.id
-            workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
-
-            session.add(workflow_run)
-            session.commit()
+        workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
+
+        workflow_run = WorkflowRun()
+        workflow_run.id = workflow_run_id
+        workflow_run.tenant_id = workflow.tenant_id
+        workflow_run.app_id = workflow.app_id
+        workflow_run.sequence_number = new_sequence_number
+        workflow_run.workflow_id = workflow.id
+        workflow_run.type = workflow.type
+        workflow_run.triggered_from = triggered_from.value
+        workflow_run.version = workflow.version
+        workflow_run.graph = workflow.graph
+        workflow_run.inputs = json.dumps(inputs)
+        workflow_run.status = WorkflowRunStatus.RUNNING
+        workflow_run.created_by_role = created_by_role
+        workflow_run.created_by = user_id
+        workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+
+        session.add(workflow_run)
 
         return workflow_run
 
     def _handle_workflow_run_success(
         self,
+        *,
+        session: Session,
         workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
@@ -141,7 +148,7 @@ class WorkflowCycleManage:
         :param conversation_id: conversation id
         :return:
         """
-        workflow_run = self._refetch_workflow_run(workflow_run.id)
+        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
 
         outputs = WorkflowEntry.handle_special_values(outputs)
 
@@ -152,9 +159,6 @@ class WorkflowCycleManage:
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
-        db.session.commit()
-        db.session.refresh(workflow_run)
-
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
@@ -165,12 +169,12 @@ class WorkflowCycleManage:
                 )
             )
 
-        db.session.close()
-
         return workflow_run
 
     def _handle_workflow_run_partial_success(
         self,
+        *,
+        session: Session,
         workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
@@ -190,7 +194,7 @@ class WorkflowCycleManage:
         :param conversation_id: conversation id
         :return:
         """
-        workflow_run = self._refetch_workflow_run(workflow_run.id)
+        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
 
         outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
 
@@ -201,8 +205,6 @@ class WorkflowCycleManage:
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_run.exceptions_count = exceptions_count
-        db.session.commit()
-        db.session.refresh(workflow_run)
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -214,12 +216,12 @@ class WorkflowCycleManage:
                 )
             )
 
-        db.session.close()
-
         return workflow_run
 
     def _handle_workflow_run_failed(
         self,
+        *,
+        session: Session,
         workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
@@ -240,7 +242,7 @@ class WorkflowCycleManage:
         :param error: error message
         :return:
         """
-        workflow_run = self._refetch_workflow_run(workflow_run.id)
+        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
 
         workflow_run.status = status.value
         workflow_run.error = error
@@ -249,21 +251,18 @@ class WorkflowCycleManage:
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_run.exceptions_count = exceptions_count
-        db.session.commit()
 
-        running_workflow_node_executions = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(
-                WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
-                WorkflowNodeExecution.app_id == workflow_run.app_id,
-                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
-                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-                WorkflowNodeExecution.workflow_run_id == workflow_run.id,
-                WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
-            )
-            .all()
+        stmt = select(WorkflowNodeExecution).where(
+            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+            WorkflowNodeExecution.app_id == workflow_run.app_id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
         )
 
+        running_workflow_node_executions = session.scalars(stmt).all()
+
         for workflow_node_execution in running_workflow_node_executions:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
@@ -271,13 +270,6 @@ class WorkflowCycleManage:
             workflow_node_execution.elapsed_time = (
                 workflow_node_execution.finished_at - workflow_node_execution.created_at
             ).total_seconds()
-            db.session.commit()
-
-        db.session.close()
-
-        # with Session(db.engine, expire_on_commit=False) as session:
-        #     session.add(workflow_run)
-        #     session.refresh(workflow_run)
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -485,14 +477,14 @@ class WorkflowCycleManage:
     #################################################
 
     def _workflow_start_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun
+        self,
+        *,
+        session: Session,
+        task_id: str,
+        workflow_run: WorkflowRun,
     ) -> WorkflowStartStreamResponse:
-        """
-        Workflow start to stream response.
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :return:
-        """
+        # Accept the session so workflow_run stays attached and its attributes are not expired; a more elegant approach is needed.
+        _ = session
         return WorkflowStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_run.id,
@@ -506,36 +498,32 @@ class WorkflowCycleManage:
         )
 
     def _workflow_finish_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun
+        self,
+        *,
+        session: Session,
+        task_id: str,
+        workflow_run: WorkflowRun,
     ) -> WorkflowFinishStreamResponse:
-        """
-        Workflow finish to stream response.
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :return:
-        """
-        # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
-        workflow_run = db.session.merge(workflow_run)
-
-        # Refresh to ensure any expired attributes are fully loaded
-        db.session.refresh(workflow_run)
-
         created_by = None
-        if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
-            created_by_account = workflow_run.created_by_account
-            if created_by_account:
+        if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
+            stmt = select(Account).where(Account.id == workflow_run.created_by)
+            account = session.scalar(stmt)
+            if account:
                 created_by = {
-                    "id": created_by_account.id,
-                    "name": created_by_account.name,
-                    "email": created_by_account.email,
+                    "id": account.id,
+                    "name": account.name,
+                    "email": account.email,
                 }
-        else:
-            created_by_end_user = workflow_run.created_by_end_user
-            if created_by_end_user:
+        elif workflow_run.created_by_role == CreatedByRole.END_USER:
+            stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
+            end_user = session.scalar(stmt)
+            if end_user:
                 created_by = {
-                    "id": created_by_end_user.id,
-                    "user": created_by_end_user.session_id,
+                    "id": end_user.id,
+                    "user": end_user.session_id,
                 }
+        else:
+            raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
 
         return WorkflowFinishStreamResponse(
             task_id=task_id,
@@ -895,14 +883,14 @@ class WorkflowCycleManage:
 
         return None
 
-    def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+    def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
         """
         Refetch workflow run
         :param workflow_run_id: workflow run id
         :return:
         """
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
-
+        stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
+        workflow_run = session.scalar(stmt)
         if not workflow_run:
             raise WorkflowRunNotFoundError(workflow_run_id)
 

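The recurring translation in this file is from the legacy Query API to 2.0-style select() run on an explicit session: .query(X).filter(...).first() becomes session.scalar(select(X).where(...)), .query(func.max(...)).scalar() or 0 becomes session.scalar(select(func.max(...)).where(...)) or 0, and .filter(...).all() becomes session.scalars(stmt).all(). A compact, self-contained illustration of those three idioms with a toy Run model (not the repository's WorkflowRun):

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column()
    sequence_number: Mapped[int] = mapped_column()

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Run(id=1, app_id="a", sequence_number=1), Run(id=2, app_id="a", sequence_number=2)])
    session.commit()

with Session(engine) as session:
    # .query(Run).filter(Run.id == 1).first()  ->  session.scalar(select(...).where(...))
    run = session.scalar(select(Run).where(Run.id == 1))
    # .query(func.max(...)).filter(...).scalar() or 0  ->  session.scalar(select(func.max(...)).where(...)) or 0
    max_seq = session.scalar(select(func.max(Run.sequence_number)).where(Run.app_id == "a")) or 0
    # .query(Run).filter(...).all()  ->  session.scalars(select(Run).where(...)).all()
    runs = session.scalars(select(Run).where(Run.app_id == "a")).all()
    assert run is not None and max_seq == 2 and len(runs) == 2
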
+ 103 - 83
api/core/ops/ops_trace_manager.py

@@ -9,6 +9,8 @@ from typing import Any, Optional, Union
 from uuid import UUID, uuid4
 
 from flask import current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
 from core.ops.entities.config_entity import (
@@ -329,15 +331,15 @@ class TraceTask:
     ):
         self.trace_type = trace_type
         self.message_id = message_id
-        self.workflow_run = workflow_run
+        self.workflow_run_id = workflow_run.id if workflow_run else None
         self.conversation_id = conversation_id
         self.user_id = user_id
         self.timer = timer
-        self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
-
         self.app_id = None
 
+        self.kwargs = kwargs
+
     def execute(self):
         return self.preprocess()
 
@@ -345,19 +347,23 @@ class TraceTask:
         preprocess_map = {
             TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
             TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
-                self.workflow_run, self.conversation_id, self.user_id
+                workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
+            ),
+            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
+            TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
             ),
-            TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
-            TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
             TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
-                self.message_id, self.timer, **self.kwargs
+                message_id=self.message_id, timer=self.timer, **self.kwargs
             ),
             TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
-                self.message_id, self.timer, **self.kwargs
+                message_id=self.message_id, timer=self.timer, **self.kwargs
+            ),
+            TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
+                message_id=self.message_id, timer=self.timer, **self.kwargs
             ),
-            TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
             TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
-                self.conversation_id, self.timer, **self.kwargs
+                conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
             ),
         }
 
@@ -367,86 +373,100 @@ class TraceTask:
     def conversation_trace(self, **kwargs):
         return kwargs
 
-    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
-        if not workflow_run:
-            raise ValueError("Workflow run not found")
-
-        db.session.merge(workflow_run)
-        db.session.refresh(workflow_run)
-
-        workflow_id = workflow_run.workflow_id
-        tenant_id = workflow_run.tenant_id
-        workflow_run_id = workflow_run.id
-        workflow_run_elapsed_time = workflow_run.elapsed_time
-        workflow_run_status = workflow_run.status
-        workflow_run_inputs = workflow_run.inputs_dict
-        workflow_run_outputs = workflow_run.outputs_dict
-        workflow_run_version = workflow_run.version
-        error = workflow_run.error or ""
-
-        total_tokens = workflow_run.total_tokens
-
-        file_list = workflow_run_inputs.get("sys.file") or []
-        query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
-
-        # get workflow_app_log_id
-        workflow_app_log_data = (
-            db.session.query(WorkflowAppLog)
-            .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
-            .first()
-        )
-        workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
-        # get message_id
-        message_data = (
-            db.session.query(Message.id)
-            .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
-            .first()
-        )
-        message_id = str(message_data.id) if message_data else None
-
-        metadata = {
-            "workflow_id": workflow_id,
-            "conversation_id": conversation_id,
-            "workflow_run_id": workflow_run_id,
-            "tenant_id": tenant_id,
-            "elapsed_time": workflow_run_elapsed_time,
-            "status": workflow_run_status,
-            "version": workflow_run_version,
-            "total_tokens": total_tokens,
-            "file_list": file_list,
-            "triggered_form": workflow_run.triggered_from,
-            "user_id": user_id,
-        }
+    def workflow_trace(
+        self,
+        *,
+        workflow_run_id: str | None,
+        conversation_id: str | None,
+        user_id: str | None,
+    ):
+        if not workflow_run_id:
+            return {}
 
-        workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run.to_dict(),
-            conversation_id=conversation_id,
-            workflow_id=workflow_id,
-            tenant_id=tenant_id,
-            workflow_run_id=workflow_run_id,
-            workflow_run_elapsed_time=workflow_run_elapsed_time,
-            workflow_run_status=workflow_run_status,
-            workflow_run_inputs=workflow_run_inputs,
-            workflow_run_outputs=workflow_run_outputs,
-            workflow_run_version=workflow_run_version,
-            error=error,
-            total_tokens=total_tokens,
-            file_list=file_list,
-            query=query,
-            metadata=metadata,
-            workflow_app_log_id=workflow_app_log_id,
-            message_id=message_id,
-            start_time=workflow_run.created_at,
-            end_time=workflow_run.finished_at,
-        )
+        with Session(db.engine) as session:
+            workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
+            workflow_run = session.scalars(workflow_run_stmt).first()
+            if not workflow_run:
+                raise ValueError("Workflow run not found")
+
+            workflow_id = workflow_run.workflow_id
+            tenant_id = workflow_run.tenant_id
+            workflow_run_id = workflow_run.id
+            workflow_run_elapsed_time = workflow_run.elapsed_time
+            workflow_run_status = workflow_run.status
+            workflow_run_inputs = workflow_run.inputs_dict
+            workflow_run_outputs = workflow_run.outputs_dict
+            workflow_run_version = workflow_run.version
+            error = workflow_run.error or ""
+
+            total_tokens = workflow_run.total_tokens
+
+            file_list = workflow_run_inputs.get("sys.file") or []
+            query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
+
+            # get workflow_app_log_id
+            workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
+                WorkflowAppLog.tenant_id == tenant_id,
+                WorkflowAppLog.app_id == workflow_run.app_id,
+                WorkflowAppLog.workflow_run_id == workflow_run.id,
+            )
+            workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
+            # get message_id
+            message_id = None
+            if conversation_id:
+                message_data_stmt = select(Message.id).where(
+                    Message.conversation_id == conversation_id,
+                    Message.workflow_run_id == workflow_run_id,
+                )
+                message_id = session.scalar(message_data_stmt)
+
+            metadata = {
+                "workflow_id": workflow_id,
+                "conversation_id": conversation_id,
+                "workflow_run_id": workflow_run_id,
+                "tenant_id": tenant_id,
+                "elapsed_time": workflow_run_elapsed_time,
+                "status": workflow_run_status,
+                "version": workflow_run_version,
+                "total_tokens": total_tokens,
+                "file_list": file_list,
+                "triggered_form": workflow_run.triggered_from,
+                "user_id": user_id,
+            }
 
+            workflow_trace_info = WorkflowTraceInfo(
+                workflow_data=workflow_run.to_dict(),
+                conversation_id=conversation_id,
+                workflow_id=workflow_id,
+                tenant_id=tenant_id,
+                workflow_run_id=workflow_run_id,
+                workflow_run_elapsed_time=workflow_run_elapsed_time,
+                workflow_run_status=workflow_run_status,
+                workflow_run_inputs=workflow_run_inputs,
+                workflow_run_outputs=workflow_run_outputs,
+                workflow_run_version=workflow_run_version,
+                error=error,
+                total_tokens=total_tokens,
+                file_list=file_list,
+                query=query,
+                metadata=metadata,
+                workflow_app_log_id=workflow_app_log_id,
+                message_id=message_id,
+                start_time=workflow_run.created_at,
+                end_time=workflow_run.finished_at,
+            )
         return workflow_trace_info
 
-    def message_trace(self, message_id):
+    def message_trace(self, message_id: str | None):
+        if not message_id:
+            return {}
         message_data = get_message_data(message_id)
         if not message_data:
             return {}
-        conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first()
+        conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
+        conversation_mode = db.session.scalars(conversation_mode_stmt).all()
+        if not conversation_mode:
+            return {}
         conversation_mode = conversation_mode[0]
         created_at = message_data.created_at
         inputs = message_data.message

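The trace queue follows the same id-not-instance rule: TraceTask now captures workflow_run.id at enqueue time and re-loads the row inside its own Session when it executes, rather than carrying a detached instance through merge() and refresh(). A hedged sketch of that shape; ToyTask and the Run model here are illustrative, not the classes in api/core/ops:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "runs"
    id: Mapped[str] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="running")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

class ToyTask:
    def __init__(self, run: Run | None):
        # Capture only the primary key; the instance may be detached by the time we execute.
        self.run_id = run.id if run else None

    def execute(self) -> dict:
        if not self.run_id:
            return {}
        with Session(engine) as session:
            run = session.scalar(select(Run).where(Run.id == self.run_id))
            if not run:
                raise ValueError("Workflow run not found")
            return {"run_id": run.id, "status": run.status}

with Session(engine) as session:
    session.add(Run(id="r1", status="succeeded"))
    session.commit()

print(ToyTask(run=None).execute())   # {}
with Session(engine) as session:
    run = session.scalar(select(Run).where(Run.id == "r1"))
print(ToyTask(run=run).execute())    # {'run_id': 'r1', 'status': 'succeeded'}
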
+ 1 - 1
api/core/ops/utils.py

@@ -18,7 +18,7 @@ def filter_none_values(data: dict):
     return new_data
 
 
-def get_message_data(message_id):
+def get_message_data(message_id: str):
     return db.session.query(Message).filter(Message.id == message_id).first()
 
 

+ 2 - 1
api/models/account.py

@@ -3,6 +3,7 @@ import json
 
 from flask_login import UserMixin  # type: ignore
 from sqlalchemy import func
+from sqlalchemy.orm import Mapped, mapped_column
 
 from .engine import db
 from .types import StringUUID
@@ -20,7 +21,7 @@ class Account(UserMixin, db.Model):  # type: ignore[name-defined]
     __tablename__ = "accounts"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     name = db.Column(db.String(255), nullable=False)
     email = db.Column(db.String(255), nullable=False)
     password = db.Column(db.String(255), nullable=True)

+ 5 - 5
api/models/model.py

@@ -530,13 +530,13 @@ class Conversation(db.Model):  # type: ignore[name-defined]
         db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
     app_model_config_id = db.Column(StringUUID, nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
     override_model_configs = db.Column(db.Text)
     model_id = db.Column(db.String(255), nullable=True)
-    mode = db.Column(db.String(255), nullable=False)
+    mode: Mapped[str] = mapped_column(db.String(255))
     name = db.Column(db.String(255), nullable=False)
     summary = db.Column(db.Text)
     _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
@@ -770,7 +770,7 @@ class Message(db.Model):  # type: ignore[name-defined]
         db.Index("message_created_at_idx", "created_at"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
     model_provider = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
@@ -797,7 +797,7 @@ class Message(db.Model):  # type: ignore[name-defined]
     from_source = db.Column(db.String(255), nullable=False)
     from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
     from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     workflow_run_id = db.Column(StringUUID)
@@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model):  # type: ignore[name-defined]
     external_user_id = db.Column(db.String(255), nullable=True)
     name = db.Column(db.String(255))
     is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
-    session_id = db.Column(db.String(255), nullable=False)
+    session_id: Mapped[str] = mapped_column(db.String(255))
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 

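These model hunks lean on SQLAlchemy 2.0 annotation inference: with mapped_column(), the column type can be taken from the Python annotation and nullability follows from whether the annotation is Optional, so Mapped[str] implies NOT NULL while Mapped[Optional[str]] stays nullable. A small self-contained check of that behaviour with a toy Demo model (not one of the repository's tables):

from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Demo(Base):
    __tablename__ = "demo"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    required_name: Mapped[str] = mapped_column(String(255))  # NOT NULL, VARCHAR(255)
    optional_note: Mapped[Optional[str]] = mapped_column()   # nullable, type inferred from the annotation

assert Demo.__table__.c.required_name.nullable is False
assert Demo.__table__.c.optional_note.nullable is True

That inference is why columns that can legitimately be NULL, such as WorkflowRun.error below, should keep the Optional annotation when they are migrated to mapped_column().
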
+ 18 - 30
api/models/workflow.py

@@ -392,40 +392,28 @@ class WorkflowRun(db.Model):  # type: ignore[name-defined]
         db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    app_id = db.Column(StringUUID, nullable=False)
-    sequence_number = db.Column(db.Integer, nullable=False)
-    workflow_id = db.Column(StringUUID, nullable=False)
-    type = db.Column(db.String(255), nullable=False)
-    triggered_from = db.Column(db.String(255), nullable=False)
-    version = db.Column(db.String(255), nullable=False)
-    graph = db.Column(db.Text)
-    inputs = db.Column(db.Text)
-    status = db.Column(db.String(255), nullable=False)  # running, succeeded, failed, stopped, partial-succeeded
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID)
+    app_id: Mapped[str] = mapped_column(StringUUID)
+    sequence_number: Mapped[int] = mapped_column()
+    workflow_id: Mapped[str] = mapped_column(StringUUID)
+    type: Mapped[str] = mapped_column(db.String(255))
+    triggered_from: Mapped[str] = mapped_column(db.String(255))
+    version: Mapped[str] = mapped_column(db.String(255))
+    graph: Mapped[str] = mapped_column(db.Text)
+    inputs: Mapped[str] = mapped_column(db.Text)
+    status: Mapped[str] = mapped_column(db.String(255))  # running, succeeded, failed, stopped, partial-succeeded
     outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
-    error = db.Column(db.Text)
+    error: Mapped[Optional[str]] = mapped_column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
-    total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+    total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
     total_steps = db.Column(db.Integer, server_default=db.text("0"))
-    created_by_role = db.Column(db.String(255), nullable=False)  # account, end_user
+    created_by_role: Mapped[str] = mapped_column(db.String(255))  # account, end_user
     created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     finished_at = db.Column(db.DateTime)
     exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
 
-    @property
-    def created_by_account(self):
-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
-
-    @property
-    def created_by_end_user(self):
-        from models.model import EndUser
-
-        created_by_role = CreatedByRole(self.created_by_role)
-        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
-
     @property
     def graph_dict(self):
         return json.loads(self.graph) if self.graph else {}
@@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model):  # type: ignore[name-defined]
         db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
     )
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    app_id = db.Column(StringUUID, nullable=False)
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID)
+    app_id: Mapped[str] = mapped_column(StringUUID)
     workflow_id = db.Column(StringUUID, nullable=False)
-    workflow_run_id = db.Column(StringUUID, nullable=False)
+    workflow_run_id: Mapped[str] = mapped_column(StringUUID)
     created_from = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)