Procházet zdrojové kódy

refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- před 3 měsíci
rodič
revize
7ed6485f86

+ 140 - 108
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -67,24 +67,17 @@ from models.account import Account
 from models.enums import CreatedByRole
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
 
 logger = logging.getLogger(__name__)
 
 
-class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
+class AdvancedChatAppGenerateTaskPipeline:
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: WorkflowTaskState
-    _application_generate_entity: AdvancedChatAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-    _conversation_name_generate_thread: Optional[Thread] = None
-
     def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -96,7 +89,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         stream: bool,
         dialogue_count: int,
     ) -> None:
-        super().__init__(
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
             stream=stream,
@@ -113,32 +106,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
 
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.QUERY: message.query,
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.CONVERSATION_ID: conversation.id,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )
+
+        self._task_state = WorkflowTaskState()
+        self._message_cycle_manager = MessageCycleManage(
+            application_generate_entity=application_generate_entity, task_state=self._task_state
+        )
+
+        self._application_generate_entity = application_generate_entity
         self._workflow_id = workflow.id
         self._workflow_features_dict = workflow.features_dict
-
         self._conversation_id = conversation.id
         self._conversation_mode = conversation.mode
-
         self._message_id = message.id
         self._message_created_at = int(message.created_at.timestamp())
-
-        self._workflow_system_variables = {
-            SystemVariableKey.QUERY: message.query,
-            SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: user_session_id,
-            SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
-            SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-            SystemVariableKey.WORKFLOW_ID: workflow.id,
-            SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-        }
-
-        self._task_state = WorkflowTaskState()
-        self._wip_workflow_node_executions = {}
-
-        self._conversation_name_generate_thread = None
+        self._conversation_name_generate_thread: Thread | None = None
         self._recorded_files: list[Mapping[str, Any]] = []
-        self._workflow_run_id = ""
+        self._workflow_run_id: str = ""
 
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
@@ -146,13 +142,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :return:
         """
         # start generate conversation name thread
-        self._conversation_name_generate_thread = self._generate_conversation_name(
+        self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
             conversation_id=self._conversation_id, query=self._application_generate_entity.query
         )
 
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
 
-        if self._stream:
+        if self._base_task_pipeline._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -269,24 +265,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         # init fake graph runtime state
         graph_runtime_state: Optional[GraphRuntimeState] = None
 
-        for queue_message in self._queue_manager.listen():
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
             event = queue_message.event
 
             if isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self._base_task_pipeline._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                with Session(db.engine) as session:
-                    err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    err = self._base_task_pipeline._handle_error(
+                        event=event, session=session, message_id=self._message_id
+                    )
                     session.commit()
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     # init workflow run
-                    workflow_run = self._handle_workflow_run_start(
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                         session=session,
                         workflow_id=self._workflow_id,
                         user_id=self._user_id,
@@ -297,7 +295,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     if not message:
                         raise ValueError(f"Message not found: {self._message_id}")
                     message.workflow_run_id = workflow_run.id
-                    workflow_start_resp = self._workflow_start_to_stream_response(
+                    workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -310,12 +308,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_retry_resp = self._workflow_node_retry_to_stream_response(
+                    node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -329,13 +329,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_node_execution_start(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                         session=session, workflow_run=workflow_run, event=event
                     )
 
-                    node_start_resp = self._workflow_node_start_to_stream_response(
+                    node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -348,12 +350,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(event, QueueNodeSucceededEvent):
                 # Record files if it's an answer node or end node
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
-                    self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+                    self._recorded_files.extend(
+                        self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
+                    )
 
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )
 
-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -364,10 +370,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if node_finish_resp:
                     yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+                        session=session, event=event
+                    )
 
-                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                    node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -381,13 +389,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_start_resp
@@ -395,13 +407,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_finish_resp
@@ -409,9 +425,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -423,9 +441,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -437,9 +457,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -454,8 +476,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not graph_runtime_state:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -466,21 +488,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         trace_manager=trace_manager,
                     )
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
 
                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_partial_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -491,21 +515,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         conversation_id=None,
                         trace_manager=trace_manager,
                     )
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
 
                 yield workflow_finish_resp
-                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                self._base_task_pipeline._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                )
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -517,20 +543,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         trace_manager=trace_manager,
                         exceptions_count=event.exceptions_count,
                     )
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                    err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                    err = self._base_task_pipeline._handle_error(
+                        event=err_event, session=session, message_id=self._message_id
+                    )
                     session.commit()
 
                 yield workflow_finish_resp
-                yield self._error_to_stream_response(err)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueStopEvent):
                 if self._workflow_run_id and graph_runtime_state:
-                    with Session(db.engine) as session:
-                        workflow_run = self._handle_workflow_run_failed(
+                    with Session(db.engine, expire_on_commit=False) as session:
+                        workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                             session=session,
                             workflow_run_id=self._workflow_run_id,
                             start_at=graph_runtime_state.start_at,
@@ -541,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                             conversation_id=self._conversation_id,
                             trace_manager=trace_manager,
                         )
-                        workflow_finish_resp = self._workflow_finish_to_stream_response(
+                        workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                             session=session,
                             task_id=self._application_generate_entity.task_id,
                             workflow_run=workflow_run,
@@ -555,18 +583,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._message_end_to_stream_response()
                 break
             elif isinstance(event, QueueRetrieverResourcesEvent):
-                self._handle_retriever_resources(event)
+                self._message_cycle_manager._handle_retriever_resources(event)
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                     )
                     session.commit()
             elif isinstance(event, QueueAnnotationReplyEvent):
-                self._handle_annotation_reply(event)
+                self._message_cycle_manager._handle_annotation_reply(event)
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     message = self._get_message(session=session)
                     message.message_metadata = (
                         json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -587,23 +615,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     tts_publisher.publish(queue_message)
 
                 self._task_state.answer += delta_text
-                yield self._message_to_stream_response(
+                yield self._message_cycle_manager._message_to_stream_response(
                     answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                 )
             elif isinstance(event, QueueMessageReplaceEvent):
                 # published by moderation
-                yield self._message_replace_to_stream_response(answer=event.text)
+                yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+                    self._task_state.answer
+                )
                 if output_moderation_answer:
                     self._task_state.answer = output_moderation_answer
-                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+                    yield self._message_cycle_manager._message_replace_to_stream_response(
+                        answer=output_moderation_answer
+                    )
 
                 # Save message
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                     session.commit()
 
@@ -621,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         message = self._get_message(session=session)
         message.answer = self._task_state.answer
-        message.provider_response_latency = time.perf_counter() - self._start_at
+        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
         message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
@@ -685,20 +717,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._output_moderation_handler:
-            if self._output_moderation_handler.should_direct_output():
+        if self._base_task_pipeline._output_moderation_handler:
+            if self._base_task_pipeline._output_moderation_handler.should_direct_output():
                 # stop subscribe new token when output moderation should direct output
-                self._task_state.answer = self._output_moderation_handler.get_final_output()
-                self._queue_manager.publish(
+                self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+                self._base_task_pipeline._queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
 
-                self._queue_manager.publish(
+                self._base_task_pipeline._queue_manager.publish(
                     QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                 )
                 return True
             else:
-                self._output_moderation_handler.append_new_token(text)
+                self._base_task_pipeline._output_moderation_handler.append_new_token(text)
 
         return False
 

+ 89 - 73
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -1,7 +1,7 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 from sqlalchemy.orm import Session
 
@@ -58,7 +58,6 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
-    WorkflowNodeExecution,
     WorkflowRun,
     WorkflowRunStatus,
 )
@@ -66,16 +65,11 @@ from models.workflow import (
 logger = logging.getLogger(__name__)
 
 
-class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
+class WorkflowAppGenerateTaskPipeline:
     """
     WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: WorkflowTaskState
-    _application_generate_entity: WorkflowAppGenerateEntity
-    _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-
     def __init__(
         self,
         application_generate_entity: WorkflowAppGenerateEntity,
@@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
-        super().__init__(
+        self._base_task_pipeline = BasedGenerateTaskPipeline(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
             stream=stream,
@@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         else:
             raise ValueError(f"Invalid user type: {type(user)}")
 
+        self._workflow_cycle_manager = WorkflowCycleManage(
+            application_generate_entity=application_generate_entity,
+            workflow_system_variables={
+                SystemVariableKey.FILES: application_generate_entity.files,
+                SystemVariableKey.USER_ID: user_session_id,
+                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                SystemVariableKey.WORKFLOW_ID: workflow.id,
+                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+            },
+        )
+
+        self._application_generate_entity = application_generate_entity
         self._workflow_id = workflow.id
         self._workflow_features_dict = workflow.features_dict
-
-        self._workflow_system_variables = {
-            SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.USER_ID: user_session_id,
-            SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-            SystemVariableKey.WORKFLOW_ID: workflow.id,
-            SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-        }
-
         self._task_state = WorkflowTaskState()
-        self._wip_workflow_node_executions = {}
         self._workflow_run_id = ""
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         :return:
         """
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-        if self._stream:
+        if self._base_task_pipeline._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         """
         graph_runtime_state = None
 
-        for queue_message in self._queue_manager.listen():
+        for queue_message in self._base_task_pipeline._queue_manager.listen():
             event = queue_message.event
 
             if isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+                yield self._base_task_pipeline._ping_stream_response()
             elif isinstance(event, QueueErrorEvent):
-                err = self._handle_error(event=event)
-                yield self._error_to_stream_response(err)
+                err = self._base_task_pipeline._handle_error(event=event)
+                yield self._base_task_pipeline._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
                 # override graph runtime state
                 graph_runtime_state = event.graph_runtime_state
 
-                with Session(db.engine) as session:
+                with Session(db.engine, expire_on_commit=False) as session:
                     # init workflow run
-                    workflow_run = self._handle_workflow_run_start(
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                         session=session,
                         workflow_id=self._workflow_id,
                         user_id=self._user_id,
                         created_by_role=self._created_by_role,
                     )
                     self._workflow_run_id = workflow_run.id
-                    start_resp = self._workflow_start_to_stream_response(
+                    start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             ):
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    response = self._workflow_node_retry_to_stream_response(
+                    response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    workflow_node_execution = self._handle_node_execution_start(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                         session=session, workflow_run=workflow_run, event=event
                     )
-                    node_start_response = self._workflow_node_start_to_stream_response(
+                    node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if node_start_response:
                     yield node_start_response
             elif isinstance(event, QueueNodeSucceededEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
-                    node_success_response = self._workflow_node_finish_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                        session=session, event=event
+                    )
+                    node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if node_success_response:
                     yield node_success_response
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                with Session(db.engine) as session:
-                    workflow_node_execution = self._handle_workflow_node_execution_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
                         session=session,
                         event=event,
                     )
-                    node_failed_response = self._workflow_node_finish_to_stream_response(
+                    node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                         session=session,
                         event=event,
                         task_id=self._application_generate_entity.task_id,
@@ -339,13 +341,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_start_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_start_resp
@@ -354,13 +360,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run,
-                        event=event,
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    parallel_finish_resp = (
+                        self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                            session=session,
+                            task_id=self._application_generate_entity.task_id,
+                            workflow_run=workflow_run,
+                            event=event,
+                        )
                     )
 
                 yield parallel_finish_resp
@@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                        session=session, workflow_run_id=self._workflow_run_id
+                    )
+                    iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session,
                         task_id=self._application_generate_entity.task_id,
                         workflow_run=workflow_run,
@@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_partial_success(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()
@@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
 
-                with Session(db.engine) as session:
-                    workflow_run = self._handle_workflow_run_failed(
+                with Session(db.engine, expire_on_commit=False) as session:
+                    workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                         session=session,
                         workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
@@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     # save workflow app log
                     self._save_workflow_app_log(session=session, workflow_run=workflow_run)
 
-                    workflow_finish_resp = self._workflow_finish_to_stream_response(
+                    workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     session.commit()

+ 0 - 11
api/core/app/task_pipeline/based_generate_task_pipeline.py

@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
     PingStreamResponse,
-    TaskState,
 )
 from core.errors.error import QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
     BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
 
-    _task_state: TaskState
-    _application_generate_entity: AppGenerateEntity
-
     def __init__(
         self,
         application_generate_entity: AppGenerateEntity,
         queue_manager: AppQueueManager,
         stream: bool,
     ) -> None:
-        """
-        Initialize GenerateTaskPipeline.
-        :param application_generate_entity: application generate entity
-        :param queue_manager: queue manager
-        :param user: user
-        :param stream: stream
-        """
         self._application_generate_entity = application_generate_entity
         self._queue_manager = queue_manager
         self._start_at = time.perf_counter()

+ 13 - 4
api/core/app/task_pipeline/message_cycle_manage.py

@@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
 
 
 class MessageCycleManage:
-    _application_generate_entity: Union[
-        ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
-    ]
-    _task_state: Union[EasyUITaskState, WorkflowTaskState]
+    def __init__(
+        self,
+        *,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity,
+            CompletionAppGenerateEntity,
+            AgentChatAppGenerateEntity,
+            AdvancedChatAppGenerateEntity,
+        ],
+        task_state: Union[EasyUITaskState, WorkflowTaskState],
+    ) -> None:
+        self._application_generate_entity = application_generate_entity
+        self._task_state = task_state
 
     def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
         """

+ 35 - 23
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
     ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
-    WorkflowTaskState,
 )
 from core.file import FILE_MODEL_IDENTITY, File
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -58,13 +57,20 @@ from models.workflow import (
     WorkflowRunStatus,
 )
 
-from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
+from .exc import WorkflowRunNotFoundError
 
 
 class WorkflowCycleManage:
-    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
-    _task_state: WorkflowTaskState
-    _workflow_system_variables: dict[SystemVariableKey, Any]
+    def __init__(
+        self,
+        *,
+        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
+        workflow_system_variables: dict[SystemVariableKey, Any],
+    ) -> None:
+        self._workflow_run: WorkflowRun | None = None
+        self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
+        self._application_generate_entity = application_generate_entity
+        self._workflow_system_variables = workflow_system_variables
 
     def _handle_workflow_run_start(
         self,
@@ -240,7 +246,7 @@ class WorkflowCycleManage:
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_run.exceptions_count = exceptions_count
 
-        stmt = select(WorkflowNodeExecution).where(
+        stmt = select(WorkflowNodeExecution.node_execution_id).where(
             WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
             WorkflowNodeExecution.app_id == workflow_run.app_id,
             WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@@ -248,16 +254,18 @@ class WorkflowCycleManage:
             WorkflowNodeExecution.workflow_run_id == workflow_run.id,
             WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
         )
-
-        running_workflow_node_executions = session.scalars(stmt).all()
+        ids = session.scalars(stmt).all()
+        # Use self._get_workflow_node_execution here to make sure the cache is updated
+        running_workflow_node_executions = [
+            self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
+        ]
 
         for workflow_node_execution in running_workflow_node_executions:
+            now = datetime.now(UTC).replace(tzinfo=None)
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
-            workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
-            workflow_node_execution.elapsed_time = (
-                workflow_node_execution.finished_at - workflow_node_execution.created_at
-            ).total_seconds()
+            workflow_node_execution.finished_at = now
+            workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -299,6 +307,8 @@ class WorkflowCycleManage:
         workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
+
+        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution
 
     def _handle_workflow_node_execution_success(
@@ -326,6 +336,7 @@ class WorkflowCycleManage:
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time
 
+        workflow_node_execution = session.merge(workflow_node_execution)
         return workflow_node_execution
 
     def _handle_workflow_node_execution_failed(
@@ -365,6 +376,7 @@ class WorkflowCycleManage:
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.execution_metadata = execution_metadata
 
+        workflow_node_execution = session.merge(workflow_node_execution)
         return workflow_node_execution
 
     def _handle_workflow_node_execution_retried(
@@ -416,6 +428,8 @@ class WorkflowCycleManage:
         workflow_node_execution.index = event.node_run_index
 
         session.add(workflow_node_execution)
+
+        self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution
 
     #################################################
@@ -812,22 +826,20 @@ class WorkflowCycleManage:
         return None
 
     def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
-        """
-        Refetch workflow run
-        :param workflow_run_id: workflow run id
-        :return:
-        """
+        if self._workflow_run and self._workflow_run.id == workflow_run_id:
+            cached_workflow_run = self._workflow_run
+            cached_workflow_run = session.merge(cached_workflow_run)
+            return cached_workflow_run
         stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
         workflow_run = session.scalar(stmt)
         if not workflow_run:
             raise WorkflowRunNotFoundError(workflow_run_id)
+        self._workflow_run = workflow_run
 
         return workflow_run
 
     def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
-        stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
-        workflow_node_execution = session.scalar(stmt)
-        if not workflow_node_execution:
-            raise WorkflowNodeExecutionNotFoundError(node_execution_id)
-
-        return workflow_node_execution
+        if node_execution_id not in self._workflow_node_executions:
+            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
+        cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
+        return cached_workflow_node_execution

+ 0 - 0
api/core/app/task_pipeline/workflow_cycle_state_manager.py