Преглед на файлове

fix(workflow_service): assign UUID to workflow_node_execution id and update optional fields in WorkflowRun and WorkflowNodeExecution models (#12096)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- преди 4 месеца
родител
ревизия
822af70dce

+ 119 - 68
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -68,7 +68,6 @@ from models.enums import CreatedByRole
 from models.workflow import (
 from models.workflow import (
     Workflow,
     Workflow,
     WorkflowNodeExecution,
     WorkflowNodeExecution,
-    WorkflowRun,
     WorkflowRunStatus,
     WorkflowRunStatus,
 )
 )
 
 
@@ -104,10 +103,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         )
         )
 
 
         if isinstance(user, EndUser):
         if isinstance(user, EndUser):
-            self._user_id = user.session_id
+            self._user_id = user.id
+            user_session_id = user.session_id
             self._created_by_role = CreatedByRole.END_USER
             self._created_by_role = CreatedByRole.END_USER
         elif isinstance(user, Account):
         elif isinstance(user, Account):
             self._user_id = user.id
             self._user_id = user.id
+            user_session_id = user.id
             self._created_by_role = CreatedByRole.ACCOUNT
             self._created_by_role = CreatedByRole.ACCOUNT
         else:
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
             raise NotImplementedError(f"User type not supported: {type(user)}")
@@ -125,7 +126,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.CONVERSATION_ID: conversation.id,
             SystemVariableKey.CONVERSATION_ID: conversation.id,
-            SystemVariableKey.USER_ID: self._user_id,
+            SystemVariableKey.USER_ID: user_session_id,
             SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
             SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
@@ -137,6 +138,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
 
         self._conversation_name_generate_thread = None
         self._conversation_name_generate_thread = None
         self._recorded_files: list[Mapping[str, Any]] = []
         self._recorded_files: list[Mapping[str, Any]] = []
+        self._workflow_run_id = ""
 
 
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
         """
@@ -266,7 +268,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         """
         # init fake graph runtime state
         # init fake graph runtime state
         graph_runtime_state: Optional[GraphRuntimeState] = None
         graph_runtime_state: Optional[GraphRuntimeState] = None
-        workflow_run: Optional[WorkflowRun] = None
 
 
         for queue_message in self._queue_manager.listen():
         for queue_message in self._queue_manager.listen():
             event = queue_message.event
             event = queue_message.event
@@ -291,111 +292,163 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         user_id=self._user_id,
                         user_id=self._user_id,
                         created_by_role=self._created_by_role,
                         created_by_role=self._created_by_role,
                     )
                     )
+                    self._workflow_run_id = workflow_run.id
                     message = self._get_message(session=session)
                     message = self._get_message(session=session)
                     if not message:
                     if not message:
                         raise ValueError(f"Message not found: {self._message_id}")
                         raise ValueError(f"Message not found: {self._message_id}")
                     message.workflow_run_id = workflow_run.id
                     message.workflow_run_id = workflow_run.id
-                    session.commit()
-
                     workflow_start_resp = self._workflow_start_to_stream_response(
                     workflow_start_resp = self._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     )
+                    session.commit()
+
                 yield workflow_start_resp
                 yield workflow_start_resp
             elif isinstance(
             elif isinstance(
                 event,
                 event,
                 QueueNodeRetryEvent,
                 QueueNodeRetryEvent,
             ):
             ):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-                workflow_node_execution = self._handle_workflow_node_execution_retried(
-                    workflow_run=workflow_run, event=event
-                )
 
 
-                node_retry_resp = self._workflow_node_retry_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
+                    node_retry_resp = self._workflow_node_retry_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if node_retry_resp:
                 if node_retry_resp:
                     yield node_retry_resp
                     yield node_retry_resp
             elif isinstance(event, QueueNodeStartedEvent):
             elif isinstance(event, QueueNodeStartedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    workflow_node_execution = self._handle_node_execution_start(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
 
 
-                node_start_resp = self._workflow_node_start_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                    node_start_resp = self._workflow_node_start_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if node_start_resp:
                 if node_start_resp:
                     yield node_start_resp
                     yield node_start_resp
             elif isinstance(event, QueueNodeSucceededEvent):
             elif isinstance(event, QueueNodeSucceededEvent):
-                workflow_node_execution = self._handle_workflow_node_execution_success(event)
-
                 # Record files if it's an answer node or end node
                 # Record files if it's an answer node or end node
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                 if event.node_type in [NodeType.ANSWER, NodeType.END]:
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
 
 
-                node_finish_resp = self._workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                with Session(db.engine) as session:
+                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+
+                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if node_finish_resp:
                 if node_finish_resp:
                     yield node_finish_resp
                     yield node_finish_resp
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                workflow_node_execution = self._handle_workflow_node_execution_failed(event)
+                with Session(db.engine) as session:
+                    workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+
+                    node_finish_resp = self._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
-                node_finish_resp = self._workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
                 if node_finish_resp:
                 if node_finish_resp:
                     yield node_finish_resp
                     yield node_finish_resp
-
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_parallel_branch_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield parallel_start_resp
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_parallel_branch_finished_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield parallel_finish_resp
             elif isinstance(event, QueueIterationStartEvent):
             elif isinstance(event, QueueIterationStartEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_start_resp
             elif isinstance(event, QueueIterationNextEvent):
             elif isinstance(event, QueueIterationNextEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_next_resp
             elif isinstance(event, QueueIterationCompletedEvent):
             elif isinstance(event, QueueIterationCompletedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_finish_resp
             elif isinstance(event, QueueWorkflowSucceededEvent):
             elif isinstance(event, QueueWorkflowSucceededEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
                 if not graph_runtime_state:
                 if not graph_runtime_state:
@@ -404,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_success(
                     workflow_run = self._handle_workflow_run_success(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -421,16 +474,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield workflow_finish_resp
                 yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-
                 if not graph_runtime_state:
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
                     raise ValueError("graph runtime state not initialized.")
 
 
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run = self._handle_workflow_run_partial_success(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -439,7 +491,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         conversation_id=None,
                         conversation_id=None,
                         trace_manager=trace_manager,
                         trace_manager=trace_manager,
                     )
                     )
-
                     workflow_finish_resp = self._workflow_finish_to_stream_response(
                     workflow_finish_resp = self._workflow_finish_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     )
@@ -448,16 +499,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield workflow_finish_resp
                 yield workflow_finish_resp
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
             elif isinstance(event, QueueWorkflowFailedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-
                 if not graph_runtime_state:
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
                     raise ValueError("graph runtime state not initialized.")
 
 
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_failed(
                     workflow_run = self._handle_workflow_run_failed(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -473,15 +523,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                     err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                     err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
                     err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
                     session.commit()
                     session.commit()
+
                 yield workflow_finish_resp
                 yield workflow_finish_resp
                 yield self._error_to_stream_response(err)
                 yield self._error_to_stream_response(err)
                 break
                 break
             elif isinstance(event, QueueStopEvent):
             elif isinstance(event, QueueStopEvent):
-                if workflow_run and graph_runtime_state:
+                if self._workflow_run_id and graph_runtime_state:
                     with Session(db.engine) as session:
                     with Session(db.engine) as session:
                         workflow_run = self._handle_workflow_run_failed(
                         workflow_run = self._handle_workflow_run_failed(
                             session=session,
                             session=session,
-                            workflow_run=workflow_run,
+                            workflow_run_id=self._workflow_run_id,
                             start_at=graph_runtime_state.start_at,
                             start_at=graph_runtime_state.start_at,
                             total_tokens=graph_runtime_state.total_tokens,
                             total_tokens=graph_runtime_state.total_tokens,
                             total_steps=graph_runtime_state.node_run_steps,
                             total_steps=graph_runtime_state.node_run_steps,
@@ -490,7 +541,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                             conversation_id=self._conversation_id,
                             conversation_id=self._conversation_id,
                             trace_manager=trace_manager,
                             trace_manager=trace_manager,
                         )
                         )
-
                         workflow_finish_resp = self._workflow_finish_to_stream_response(
                         workflow_finish_resp = self._workflow_finish_to_stream_response(
                             session=session,
                             session=session,
                             task_id=self._application_generate_entity.task_id,
                             task_id=self._application_generate_entity.task_id,
@@ -499,6 +549,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         # Save message
                         # Save message
                         self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                         self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                         session.commit()
                         session.commit()
+
                     yield workflow_finish_resp
                     yield workflow_finish_resp
 
 
                 yield self._message_end_to_stream_response()
                 yield self._message_end_to_stream_response()

+ 123 - 63
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -91,10 +91,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         )
         )
 
 
         if isinstance(user, EndUser):
         if isinstance(user, EndUser):
-            self._user_id = user.session_id
+            self._user_id = user.id
+            user_session_id = user.session_id
             self._created_by_role = CreatedByRole.END_USER
             self._created_by_role = CreatedByRole.END_USER
         elif isinstance(user, Account):
         elif isinstance(user, Account):
             self._user_id = user.id
             self._user_id = user.id
+            user_session_id = user.id
             self._created_by_role = CreatedByRole.ACCOUNT
             self._created_by_role = CreatedByRole.ACCOUNT
         else:
         else:
             raise ValueError(f"Invalid user type: {type(user)}")
             raise ValueError(f"Invalid user type: {type(user)}")
@@ -104,7 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
 
         self._workflow_system_variables = {
         self._workflow_system_variables = {
             SystemVariableKey.FILES: application_generate_entity.files,
             SystemVariableKey.FILES: application_generate_entity.files,
-            SystemVariableKey.USER_ID: self._user_id,
+            SystemVariableKey.USER_ID: user_session_id,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
             SystemVariableKey.WORKFLOW_ID: workflow.id,
             SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
             SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@@ -112,6 +114,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
 
         self._task_state = WorkflowTaskState()
         self._task_state = WorkflowTaskState()
         self._wip_workflow_node_executions = {}
         self._wip_workflow_node_executions = {}
+        self._workflow_run_id = ""
 
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
         """
@@ -233,7 +236,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         :return:
         :return:
         """
         """
         graph_runtime_state = None
         graph_runtime_state = None
-        workflow_run = None
 
 
         for queue_message in self._queue_manager.listen():
         for queue_message in self._queue_manager.listen():
             event = queue_message.event
             event = queue_message.event
@@ -256,111 +258,168 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                         user_id=self._user_id,
                         user_id=self._user_id,
                         created_by_role=self._created_by_role,
                         created_by_role=self._created_by_role,
                     )
                     )
+                    self._workflow_run_id = workflow_run.id
                     start_resp = self._workflow_start_to_stream_response(
                     start_resp = self._workflow_start_to_stream_response(
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     )
                     session.commit()
                     session.commit()
+
                 yield start_resp
                 yield start_resp
             elif isinstance(
             elif isinstance(
                 event,
                 event,
                 QueueNodeRetryEvent,
                 QueueNodeRetryEvent,
             ):
             ):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-                workflow_node_execution = self._handle_workflow_node_execution_retried(
-                    workflow_run=workflow_run, event=event
-                )
-
-                response = self._workflow_node_retry_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    workflow_node_execution = self._handle_workflow_node_execution_retried(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
+                    response = self._workflow_node_retry_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if response:
                 if response:
                     yield response
                     yield response
             elif isinstance(event, QueueNodeStartedEvent):
             elif isinstance(event, QueueNodeStartedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
-
-                node_start_response = self._workflow_node_start_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    workflow_node_execution = self._handle_node_execution_start(
+                        session=session, workflow_run=workflow_run, event=event
+                    )
+                    node_start_response = self._workflow_node_start_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if node_start_response:
                 if node_start_response:
                     yield node_start_response
                     yield node_start_response
             elif isinstance(event, QueueNodeSucceededEvent):
             elif isinstance(event, QueueNodeSucceededEvent):
-                workflow_node_execution = self._handle_workflow_node_execution_success(event)
-
-                node_success_response = self._workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
+                with Session(db.engine) as session:
+                    workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+                    node_success_response = self._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
                 if node_success_response:
                 if node_success_response:
                     yield node_success_response
                     yield node_success_response
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
             elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                workflow_node_execution = self._handle_workflow_node_execution_failed(event)
+                with Session(db.engine) as session:
+                    workflow_node_execution = self._handle_workflow_node_execution_failed(
+                        session=session,
+                        event=event,
+                    )
+                    node_failed_response = self._workflow_node_finish_to_stream_response(
+                        session=session,
+                        event=event,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_node_execution=workflow_node_execution,
+                    )
+                    session.commit()
 
 
-                node_failed_response = self._workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
                 if node_failed_response:
                 if node_failed_response:
                     yield node_failed_response
                     yield node_failed_response
 
 
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
             elif isinstance(event, QueueParallelBranchRunStartedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_parallel_branch_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield parallel_start_resp
+
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_parallel_branch_finished_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield parallel_finish_resp
+
             elif isinstance(event, QueueIterationStartEvent):
             elif isinstance(event, QueueIterationStartEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_start_resp
+
             elif isinstance(event, QueueIterationNextEvent):
             elif isinstance(event, QueueIterationNextEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_next_resp
+
             elif isinstance(event, QueueIterationCompletedEvent):
             elif isinstance(event, QueueIterationCompletedEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
 
 
-                yield self._workflow_iteration_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
-                )
+                with Session(db.engine) as session:
+                    workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
+                    iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                        session=session,
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run,
+                        event=event,
+                    )
+
+                yield iter_finish_resp
+
             elif isinstance(event, QueueWorkflowSucceededEvent):
             elif isinstance(event, QueueWorkflowSucceededEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-
                 if not graph_runtime_state:
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
                     raise ValueError("graph runtime state not initialized.")
 
 
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_success(
                     workflow_run = self._handle_workflow_run_success(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -378,18 +437,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                         workflow_run=workflow_run,
                         workflow_run=workflow_run,
                     )
                     )
                     session.commit()
                     session.commit()
+
                 yield workflow_finish_resp
                 yield workflow_finish_resp
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-
                 if not graph_runtime_state:
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
                     raise ValueError("graph runtime state not initialized.")
 
 
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run = self._handle_workflow_run_partial_success(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -409,15 +468,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
 
                 yield workflow_finish_resp
                 yield workflow_finish_resp
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
-                if not workflow_run:
+                if not self._workflow_run_id:
                     raise ValueError("workflow run not initialized.")
                     raise ValueError("workflow run not initialized.")
-
                 if not graph_runtime_state:
                 if not graph_runtime_state:
                     raise ValueError("graph runtime state not initialized.")
                     raise ValueError("graph runtime state not initialized.")
+
                 with Session(db.engine) as session:
                 with Session(db.engine) as session:
                     workflow_run = self._handle_workflow_run_failed(
                     workflow_run = self._handle_workflow_run_failed(
                         session=session,
                         session=session,
-                        workflow_run=workflow_run,
+                        workflow_run_id=self._workflow_run_id,
                         start_at=graph_runtime_state.start_at,
                         start_at=graph_runtime_state.start_at,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_tokens=graph_runtime_state.total_tokens,
                         total_steps=graph_runtime_state.node_run_steps,
                         total_steps=graph_runtime_state.node_run_steps,
@@ -437,6 +496,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                         session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                     )
                     )
                     session.commit()
                     session.commit()
+
                 yield workflow_finish_resp
                 yield workflow_finish_resp
             elif isinstance(event, QueueTextChunkEvent):
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 delta_text = event.text

+ 94 - 172
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -46,7 +46,6 @@ from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes import NodeType
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_entry import WorkflowEntry
-from extensions.ext_database import db
 from models.account import Account
 from models.account import Account
 from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
 from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
 from models.model import EndUser
 from models.model import EndUser
@@ -66,7 +65,6 @@ class WorkflowCycleManage:
     _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
     _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
     _task_state: WorkflowTaskState
     _task_state: WorkflowTaskState
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _workflow_system_variables: dict[SystemVariableKey, Any]
-    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
 
 
     def _handle_workflow_run_start(
     def _handle_workflow_run_start(
         self,
         self,
@@ -130,7 +128,7 @@ class WorkflowCycleManage:
         self,
         self,
         *,
         *,
         session: Session,
         session: Session,
-        workflow_run: WorkflowRun,
+        workflow_run_id: str,
         start_at: float,
         start_at: float,
         total_tokens: int,
         total_tokens: int,
         total_steps: int,
         total_steps: int,
@@ -148,7 +146,7 @@ class WorkflowCycleManage:
         :param conversation_id: conversation id
         :param conversation_id: conversation id
         :return:
         :return:
         """
         """
-        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
+        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
 
 
         outputs = WorkflowEntry.handle_special_values(outputs)
         outputs = WorkflowEntry.handle_special_values(outputs)
 
 
@@ -175,7 +173,7 @@ class WorkflowCycleManage:
         self,
         self,
         *,
         *,
         session: Session,
         session: Session,
-        workflow_run: WorkflowRun,
+        workflow_run_id: str,
         start_at: float,
         start_at: float,
         total_tokens: int,
         total_tokens: int,
         total_steps: int,
         total_steps: int,
@@ -184,18 +182,7 @@ class WorkflowCycleManage:
         conversation_id: Optional[str] = None,
         conversation_id: Optional[str] = None,
         trace_manager: Optional[TraceQueueManager] = None,
         trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
     ) -> WorkflowRun:
-        """
-        Workflow run success
-        :param workflow_run: workflow run
-        :param start_at: start time
-        :param total_tokens: total tokens
-        :param total_steps: total steps
-        :param outputs: outputs
-        :param conversation_id: conversation id
-        :return:
-        """
-        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
-
+        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
         outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
         outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
 
 
         workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
         workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
@@ -222,7 +209,7 @@ class WorkflowCycleManage:
         self,
         self,
         *,
         *,
         session: Session,
         session: Session,
-        workflow_run: WorkflowRun,
+        workflow_run_id: str,
         start_at: float,
         start_at: float,
         total_tokens: int,
         total_tokens: int,
         total_steps: int,
         total_steps: int,
@@ -242,7 +229,7 @@ class WorkflowCycleManage:
         :param error: error message
         :param error: error message
         :return:
         :return:
         """
         """
-        workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
+        workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
 
 
         workflow_run.status = status.value
         workflow_run.status = status.value
         workflow_run.error = error
         workflow_run.error = error
@@ -284,49 +271,41 @@ class WorkflowCycleManage:
         return workflow_run
         return workflow_run
 
 
     def _handle_node_execution_start(
     def _handle_node_execution_start(
-        self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
+        self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
     ) -> WorkflowNodeExecution:
     ) -> WorkflowNodeExecution:
-        # init workflow node execution
-
-        with Session(db.engine, expire_on_commit=False) as session:
-            workflow_node_execution = WorkflowNodeExecution()
-            workflow_node_execution.tenant_id = workflow_run.tenant_id
-            workflow_node_execution.app_id = workflow_run.app_id
-            workflow_node_execution.workflow_id = workflow_run.workflow_id
-            workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
-            workflow_node_execution.workflow_run_id = workflow_run.id
-            workflow_node_execution.predecessor_node_id = event.predecessor_node_id
-            workflow_node_execution.index = event.node_run_index
-            workflow_node_execution.node_execution_id = event.node_execution_id
-            workflow_node_execution.node_id = event.node_id
-            workflow_node_execution.node_type = event.node_type.value
-            workflow_node_execution.title = event.node_data.title
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
-            workflow_node_execution.created_by_role = workflow_run.created_by_role
-            workflow_node_execution.created_by = workflow_run.created_by
-            workflow_node_execution.execution_metadata = json.dumps(
-                {
-                    NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
-                    NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
-                }
-            )
-            workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
-
-            session.add(workflow_node_execution)
-            session.commit()
-            session.refresh(workflow_node_execution)
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.id = event.node_execution_id
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.index = event.node_run_index
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.execution_metadata = json.dumps(
+            {
+                NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            }
+        )
+        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
 
-        self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
+        session.add(workflow_node_execution)
         return workflow_node_execution
         return workflow_node_execution
 
 
-    def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
-        """
-        Workflow node execution success
-        :param event: queue node succeeded event
-        :return:
-        """
-        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
-
+    def _handle_workflow_node_execution_success(
+        self, *, session: Session, event: QueueNodeSucceededEvent
+    ) -> WorkflowNodeExecution:
+        workflow_node_execution = self._get_workflow_node_execution(
+            session=session, node_execution_id=event.node_execution_id
+        )
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
@@ -336,20 +315,6 @@ class WorkflowCycleManage:
         finished_at = datetime.now(UTC).replace(tzinfo=None)
         finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
         elapsed_time = (finished_at - event.start_at).total_seconds()
 
 
-        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
-            {
-                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
-                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
-                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
-                WorkflowNodeExecution.execution_metadata: execution_metadata,
-                WorkflowNodeExecution.finished_at: finished_at,
-                WorkflowNodeExecution.elapsed_time: elapsed_time,
-            }
-        )
-
-        db.session.commit()
-        db.session.close()
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
 
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
@@ -360,19 +325,22 @@ class WorkflowCycleManage:
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.elapsed_time = elapsed_time
 
 
-        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
-
         return workflow_node_execution
         return workflow_node_execution
 
 
     def _handle_workflow_node_execution_failed(
     def _handle_workflow_node_execution_failed(
-        self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
+        self,
+        *,
+        session: Session,
+        event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent,
     ) -> WorkflowNodeExecution:
     ) -> WorkflowNodeExecution:
         """
         """
         Workflow node execution failed
         Workflow node execution failed
         :param event: queue node failed event
         :param event: queue node failed event
         :return:
         :return:
         """
         """
-        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+        workflow_node_execution = self._get_workflow_node_execution(
+            session=session, node_execution_id=event.node_execution_id
+        )
 
 
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
@@ -382,25 +350,6 @@ class WorkflowCycleManage:
         execution_metadata = (
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
         )
         )
-        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
-            {
-                WorkflowNodeExecution.status: (
-                    WorkflowNodeExecutionStatus.FAILED.value
-                    if not isinstance(event, QueueNodeExceptionEvent)
-                    else WorkflowNodeExecutionStatus.EXCEPTION.value
-                ),
-                WorkflowNodeExecution.error: event.error,
-                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
-                WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
-                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
-                WorkflowNodeExecution.finished_at: finished_at,
-                WorkflowNodeExecution.elapsed_time: elapsed_time,
-                WorkflowNodeExecution.execution_metadata: execution_metadata,
-            }
-        )
-
-        db.session.commit()
-        db.session.close()
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         workflow_node_execution.status = (
         workflow_node_execution.status = (
             WorkflowNodeExecutionStatus.FAILED.value
             WorkflowNodeExecutionStatus.FAILED.value
@@ -415,12 +364,10 @@ class WorkflowCycleManage:
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.execution_metadata = execution_metadata
         workflow_node_execution.execution_metadata = execution_metadata
 
 
-        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
-
         return workflow_node_execution
         return workflow_node_execution
 
 
     def _handle_workflow_node_execution_retried(
     def _handle_workflow_node_execution_retried(
-        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+        self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
     ) -> WorkflowNodeExecution:
     ) -> WorkflowNodeExecution:
         """
         """
         Workflow node execution failed
         Workflow node execution failed
@@ -444,6 +391,7 @@ class WorkflowCycleManage:
         execution_metadata = json.dumps(merged_metadata)
         execution_metadata = json.dumps(merged_metadata)
 
 
         workflow_node_execution = WorkflowNodeExecution()
         workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.id = event.node_execution_id
         workflow_node_execution.tenant_id = workflow_run.tenant_id
         workflow_node_execution.tenant_id = workflow_run.tenant_id
         workflow_node_execution.app_id = workflow_run.app_id
         workflow_node_execution.app_id = workflow_run.app_id
         workflow_node_execution.workflow_id = workflow_run.workflow_id
         workflow_node_execution.workflow_id = workflow_run.workflow_id
@@ -466,10 +414,7 @@ class WorkflowCycleManage:
         workflow_node_execution.execution_metadata = execution_metadata
         workflow_node_execution.execution_metadata = execution_metadata
         workflow_node_execution.index = event.node_run_index
         workflow_node_execution.index = event.node_run_index
 
 
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-
+        session.add(workflow_node_execution)
         return workflow_node_execution
         return workflow_node_execution
 
 
     #################################################
     #################################################
@@ -547,17 +492,20 @@ class WorkflowCycleManage:
         )
         )
 
 
     def _workflow_node_start_to_stream_response(
     def _workflow_node_start_to_stream_response(
-        self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
+        self,
+        *,
+        session: Session,
+        event: QueueNodeStartedEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeStartStreamResponse]:
     ) -> Optional[NodeStartStreamResponse]:
-        """
-        Workflow node start to stream response.
-        :param event: queue node started event
-        :param task_id: task id
-        :param workflow_node_execution: workflow node execution
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
+
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
             return None
+        if not workflow_node_execution.workflow_run_id:
+            return None
 
 
         response = NodeStartStreamResponse(
         response = NodeStartStreamResponse(
             task_id=task_id,
             task_id=task_id,
@@ -593,6 +541,8 @@ class WorkflowCycleManage:
 
 
     def _workflow_node_finish_to_stream_response(
     def _workflow_node_finish_to_stream_response(
         self,
         self,
+        *,
+        session: Session,
         event: QueueNodeSucceededEvent
         event: QueueNodeSucceededEvent
         | QueueNodeFailedEvent
         | QueueNodeFailedEvent
         | QueueNodeInIterationFailedEvent
         | QueueNodeInIterationFailedEvent
@@ -600,15 +550,14 @@ class WorkflowCycleManage:
         task_id: str,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
     ) -> Optional[NodeFinishStreamResponse]:
-        """
-        Workflow node finish to stream response.
-        :param event: queue node succeeded or failed event
-        :param task_id: task id
-        :param workflow_node_execution: workflow node execution
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
             return None
+        if not workflow_node_execution.workflow_run_id:
+            return None
+        if not workflow_node_execution.finished_at:
+            return None
 
 
         return NodeFinishStreamResponse(
         return NodeFinishStreamResponse(
             task_id=task_id,
             task_id=task_id,
@@ -640,19 +589,20 @@ class WorkflowCycleManage:
 
 
     def _workflow_node_retry_to_stream_response(
     def _workflow_node_retry_to_stream_response(
         self,
         self,
+        *,
+        session: Session,
         event: QueueNodeRetryEvent,
         event: QueueNodeRetryEvent,
         task_id: str,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
     ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
-        """
-        Workflow node finish to stream response.
-        :param event: queue node succeeded or failed event
-        :param task_id: task id
-        :param workflow_node_execution: workflow node execution
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
             return None
+        if not workflow_node_execution.workflow_run_id:
+            return None
+        if not workflow_node_execution.finished_at:
+            return None
 
 
         return NodeRetryStreamResponse(
         return NodeRetryStreamResponse(
             task_id=task_id,
             task_id=task_id,
@@ -684,15 +634,10 @@ class WorkflowCycleManage:
         )
         )
 
 
     def _workflow_parallel_branch_start_to_stream_response(
     def _workflow_parallel_branch_start_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
+        self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
     ) -> ParallelBranchStartStreamResponse:
-        """
-        Workflow parallel branch start to stream response
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :param event: parallel branch run started event
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         return ParallelBranchStartStreamResponse(
         return ParallelBranchStartStreamResponse(
             task_id=task_id,
             task_id=task_id,
             workflow_run_id=workflow_run.id,
             workflow_run_id=workflow_run.id,
@@ -708,17 +653,14 @@ class WorkflowCycleManage:
 
 
     def _workflow_parallel_branch_finished_to_stream_response(
     def _workflow_parallel_branch_finished_to_stream_response(
         self,
         self,
+        *,
+        session: Session,
         task_id: str,
         task_id: str,
         workflow_run: WorkflowRun,
         workflow_run: WorkflowRun,
         event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
         event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
     ) -> ParallelBranchFinishedStreamResponse:
     ) -> ParallelBranchFinishedStreamResponse:
-        """
-        Workflow parallel branch finished to stream response
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :param event: parallel branch run succeeded or failed event
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         return ParallelBranchFinishedStreamResponse(
         return ParallelBranchFinishedStreamResponse(
             task_id=task_id,
             task_id=task_id,
             workflow_run_id=workflow_run.id,
             workflow_run_id=workflow_run.id,
@@ -735,15 +677,10 @@ class WorkflowCycleManage:
         )
         )
 
 
     def _workflow_iteration_start_to_stream_response(
     def _workflow_iteration_start_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
+        self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
     ) -> IterationNodeStartStreamResponse:
     ) -> IterationNodeStartStreamResponse:
-        """
-        Workflow iteration start to stream response
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :param event: iteration start event
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         return IterationNodeStartStreamResponse(
         return IterationNodeStartStreamResponse(
             task_id=task_id,
             task_id=task_id,
             workflow_run_id=workflow_run.id,
             workflow_run_id=workflow_run.id,
@@ -762,15 +699,10 @@ class WorkflowCycleManage:
         )
         )
 
 
     def _workflow_iteration_next_to_stream_response(
     def _workflow_iteration_next_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
+        self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
     ) -> IterationNodeNextStreamResponse:
     ) -> IterationNodeNextStreamResponse:
-        """
-        Workflow iteration next to stream response
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :param event: iteration next event
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         return IterationNodeNextStreamResponse(
         return IterationNodeNextStreamResponse(
             task_id=task_id,
             task_id=task_id,
             workflow_run_id=workflow_run.id,
             workflow_run_id=workflow_run.id,
@@ -791,15 +723,10 @@ class WorkflowCycleManage:
         )
         )
 
 
     def _workflow_iteration_completed_to_stream_response(
     def _workflow_iteration_completed_to_stream_response(
-        self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
+        self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
     ) -> IterationNodeCompletedStreamResponse:
     ) -> IterationNodeCompletedStreamResponse:
-        """
-        Workflow iteration completed to stream response
-        :param task_id: task id
-        :param workflow_run: workflow run
-        :param event: iteration completed event
-        :return:
-        """
+        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
+        _ = session
         return IterationNodeCompletedStreamResponse(
         return IterationNodeCompletedStreamResponse(
             task_id=task_id,
             task_id=task_id,
             workflow_run_id=workflow_run.id,
             workflow_run_id=workflow_run.id,
@@ -883,7 +810,7 @@ class WorkflowCycleManage:
 
 
         return None
         return None
 
 
-    def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
+    def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
         """
         """
         Refetch workflow run
         Refetch workflow run
         :param workflow_run_id: workflow run id
         :param workflow_run_id: workflow run id
@@ -896,14 +823,9 @@ class WorkflowCycleManage:
 
 
         return workflow_run
         return workflow_run
 
 
-    def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
-        """
-        Refetch workflow node execution
-        :param node_execution_id: workflow node execution id
-        :return:
-        """
-        workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
-
+    def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
+        stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id)
+        workflow_node_execution = session.scalar(stmt)
         if not workflow_node_execution:
         if not workflow_node_execution:
             raise WorkflowNodeExecutionNotFoundError(node_execution_id)
             raise WorkflowNodeExecutionNotFoundError(node_execution_id)
 
 

+ 26 - 26
api/models/workflow.py

@@ -400,11 +400,11 @@ class WorkflowRun(db.Model):  # type: ignore[name-defined]
     type: Mapped[str] = mapped_column(db.String(255))
     type: Mapped[str] = mapped_column(db.String(255))
     triggered_from: Mapped[str] = mapped_column(db.String(255))
     triggered_from: Mapped[str] = mapped_column(db.String(255))
     version: Mapped[str] = mapped_column(db.String(255))
     version: Mapped[str] = mapped_column(db.String(255))
-    graph: Mapped[str] = mapped_column(db.Text)
-    inputs: Mapped[str] = mapped_column(db.Text)
+    graph: Mapped[Optional[str]] = mapped_column(db.Text)
+    inputs: Mapped[Optional[str]] = mapped_column(db.Text)
     status: Mapped[str] = mapped_column(db.String(255))  # running, succeeded, failed, stopped, partial-succeeded
     status: Mapped[str] = mapped_column(db.String(255))  # running, succeeded, failed, stopped, partial-succeeded
     outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
     outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
-    error: Mapped[str] = mapped_column(db.Text)
+    error: Mapped[Optional[str]] = mapped_column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
     total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
     total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
     total_steps = db.Column(db.Integer, server_default=db.text("0"))
     total_steps = db.Column(db.Integer, server_default=db.text("0"))
@@ -609,29 +609,29 @@ class WorkflowNodeExecution(db.Model):  # type: ignore[name-defined]
         ),
         ),
     )
     )
 
 
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    tenant_id = db.Column(StringUUID, nullable=False)
-    app_id = db.Column(StringUUID, nullable=False)
-    workflow_id = db.Column(StringUUID, nullable=False)
-    triggered_from = db.Column(db.String(255), nullable=False)
-    workflow_run_id = db.Column(StringUUID)
-    index = db.Column(db.Integer, nullable=False)
-    predecessor_node_id = db.Column(db.String(255))
-    node_execution_id = db.Column(db.String(255), nullable=True)
-    node_id = db.Column(db.String(255), nullable=False)
-    node_type = db.Column(db.String(255), nullable=False)
-    title = db.Column(db.String(255), nullable=False)
-    inputs = db.Column(db.Text)
-    process_data = db.Column(db.Text)
-    outputs = db.Column(db.Text)
-    status = db.Column(db.String(255), nullable=False)
-    error = db.Column(db.Text)
-    elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
-    execution_metadata = db.Column(db.Text)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(StringUUID, nullable=False)
-    finished_at = db.Column(db.DateTime)
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID)
+    app_id: Mapped[str] = mapped_column(StringUUID)
+    workflow_id: Mapped[str] = mapped_column(StringUUID)
+    triggered_from: Mapped[str] = mapped_column(db.String(255))
+    workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+    index: Mapped[int] = mapped_column(db.Integer)
+    predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
+    node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
+    node_id: Mapped[str] = mapped_column(db.String(255))
+    node_type: Mapped[str] = mapped_column(db.String(255))
+    title: Mapped[str] = mapped_column(db.String(255))
+    inputs: Mapped[Optional[str]] = mapped_column(db.Text)
+    process_data: Mapped[Optional[str]] = mapped_column(db.Text)
+    outputs: Mapped[Optional[str]] = mapped_column(db.Text)
+    status: Mapped[str] = mapped_column(db.String(255))
+    error: Mapped[Optional[str]] = mapped_column(db.Text)
+    elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
+    execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+    created_by_role: Mapped[str] = mapped_column(db.String(255))
+    created_by: Mapped[str] = mapped_column(StringUUID)
+    finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
 
 
     @property
     @property
     def created_by_account(self):
     def created_by_account(self):

+ 2 - 0
api/services/workflow_service.py

@@ -3,6 +3,7 @@ import time
 from collections.abc import Sequence
 from collections.abc import Sequence
 from datetime import UTC, datetime
 from datetime import UTC, datetime
 from typing import Any, Optional, cast
 from typing import Any, Optional, cast
+from uuid import uuid4
 
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -277,6 +278,7 @@ class WorkflowService:
             error = e.error
             error = e.error
 
 
         workflow_node_execution = WorkflowNodeExecution()
         workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.id = str(uuid4())
         workflow_node_execution.tenant_id = app_model.tenant_id
         workflow_node_execution.tenant_id = app_model.tenant_id
         workflow_node_execution.app_id = app_model.id
         workflow_node_execution.app_id = app_model.id
         workflow_node_execution.workflow_id = draft_workflow.id
         workflow_node_execution.workflow_id = draft_workflow.id