Selaa lähdekoodia

chore: use cache instead of re-querying node record during workflow execution (#9280)

takatost 6 kuukautta sitten
vanhempi
commit
29188e0562

+ 3 - 0
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -56,6 +56,7 @@ from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
+    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
 
@@ -72,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
 
     def __init__(
         self,
@@ -115,6 +117,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         }
 
         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}
 
         self._conversation_name_generate_thread = None
 

+ 3 - 0
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -52,6 +52,7 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
+    WorkflowNodeExecution,
     WorkflowRun,
     WorkflowRunStatus,
 )
@@ -69,6 +70,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
 
     def __init__(
         self,
@@ -103,6 +105,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         }
 
         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """

+ 8 - 13
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -57,6 +57,7 @@ class WorkflowCycleManage:
     _user: Union[Account, EndUser]
     _task_state: WorkflowTaskState
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
 
     def _handle_workflow_run_start(self) -> WorkflowRun:
         max_sequence = (
@@ -251,6 +252,8 @@ class WorkflowCycleManage:
         db.session.refresh(workflow_node_execution)
         db.session.close()
 
+        self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
+
         return workflow_node_execution
 
     def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
@@ -275,9 +278,10 @@ class WorkflowCycleManage:
         workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
 
         db.session.commit()
-        db.session.refresh(workflow_node_execution)
         db.session.close()
 
+        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
+
         return workflow_node_execution
 
     def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
@@ -300,9 +304,10 @@ class WorkflowCycleManage:
         workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
 
         db.session.commit()
-        db.session.refresh(workflow_node_execution)
         db.session.close()
 
+        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
+
         return workflow_node_execution
 
     #################################################
@@ -678,17 +683,7 @@ class WorkflowCycleManage:
         :param node_execution_id: workflow node execution id
         :return:
         """
-        workflow_node_execution = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(
-                WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
-                WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
-                WorkflowNodeExecution.workflow_id == self._workflow.id,
-                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-                WorkflowNodeExecution.node_execution_id == node_execution_id,
-            )
-            .first()
-        )
+        workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
 
         if not workflow_node_execution:
             raise Exception(f"Workflow node execution not found: {node_execution_id}")