Parcourir la source

chore: optimize the trace ops slow queries on node executions. (#9282)

takatost il y a 6 mois
Parent
commit
23ce1fb1ba

+ 42 - 13
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -266,20 +266,35 @@ class WorkflowCycleManage:
 
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
-
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
-        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = (
+        execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
         )
-        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
-        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
+        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        elapsed_time = (finished_at - event.start_at).total_seconds()
+
+        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
+            {
+                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
+                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
+                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
+                WorkflowNodeExecution.execution_metadata: execution_metadata,
+                WorkflowNodeExecution.finished_at: finished_at,
+                WorkflowNodeExecution.elapsed_time: elapsed_time,
+            }
+        )
 
         db.session.commit()
         db.session.close()
 
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = execution_metadata
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+
         self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
 
         return workflow_node_execution
@@ -294,17 +309,31 @@ class WorkflowCycleManage:
 
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
+        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        elapsed_time = (finished_at - event.start_at).total_seconds()
+
+        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
+            {
+                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
+                WorkflowNodeExecution.error: event.error,
+                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
+                WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
+                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
+                WorkflowNodeExecution.finished_at: finished_at,
+                WorkflowNodeExecution.elapsed_time: elapsed_time,
+            }
+        )
+
+        db.session.commit()
+        db.session.close()
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
         workflow_node_execution.error = event.error
-        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
         workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
-
-        db.session.commit()
-        db.session.close()
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
 
         self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
 

+ 25 - 16
api/core/ops/langfuse_trace/langfuse_trace.py

@@ -110,26 +110,35 @@ class LangFuseDataTrace(BaseTraceInstance):
             self.add_trace(langfuse_trace_data=trace_data)
 
         # through workflow_run_id get all_nodes_execution
-        workflow_nodes_executions = (
-            db.session.query(
-                WorkflowNodeExecution.id,
-                WorkflowNodeExecution.tenant_id,
-                WorkflowNodeExecution.app_id,
-                WorkflowNodeExecution.title,
-                WorkflowNodeExecution.node_type,
-                WorkflowNodeExecution.status,
-                WorkflowNodeExecution.inputs,
-                WorkflowNodeExecution.outputs,
-                WorkflowNodeExecution.created_at,
-                WorkflowNodeExecution.elapsed_time,
-                WorkflowNodeExecution.process_data,
-                WorkflowNodeExecution.execution_metadata,
-            )
+        workflow_nodes_execution_id_records = (
+            db.session.query(WorkflowNodeExecution.id)
             .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
             .all()
         )
 
-        for node_execution in workflow_nodes_executions:
+        for node_execution_id_record in workflow_nodes_execution_id_records:
+            node_execution = (
+                db.session.query(
+                    WorkflowNodeExecution.id,
+                    WorkflowNodeExecution.tenant_id,
+                    WorkflowNodeExecution.app_id,
+                    WorkflowNodeExecution.title,
+                    WorkflowNodeExecution.node_type,
+                    WorkflowNodeExecution.status,
+                    WorkflowNodeExecution.inputs,
+                    WorkflowNodeExecution.outputs,
+                    WorkflowNodeExecution.created_at,
+                    WorkflowNodeExecution.elapsed_time,
+                    WorkflowNodeExecution.process_data,
+                    WorkflowNodeExecution.execution_metadata,
+                )
+                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
+                .first()
+            )
+
+            if not node_execution:
+                continue
+
             node_execution_id = node_execution.id
             tenant_id = node_execution.tenant_id
             app_id = node_execution.app_id

+ 25 - 16
api/core/ops/langsmith_trace/langsmith_trace.py

@@ -100,26 +100,35 @@ class LangSmithDataTrace(BaseTraceInstance):
         self.add_run(langsmith_run)
 
         # through workflow_run_id get all_nodes_execution
-        workflow_nodes_executions = (
-            db.session.query(
-                WorkflowNodeExecution.id,
-                WorkflowNodeExecution.tenant_id,
-                WorkflowNodeExecution.app_id,
-                WorkflowNodeExecution.title,
-                WorkflowNodeExecution.node_type,
-                WorkflowNodeExecution.status,
-                WorkflowNodeExecution.inputs,
-                WorkflowNodeExecution.outputs,
-                WorkflowNodeExecution.created_at,
-                WorkflowNodeExecution.elapsed_time,
-                WorkflowNodeExecution.process_data,
-                WorkflowNodeExecution.execution_metadata,
-            )
+        workflow_nodes_execution_id_records = (
+            db.session.query(WorkflowNodeExecution.id)
             .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
             .all()
         )
 
-        for node_execution in workflow_nodes_executions:
+        for node_execution_id_record in workflow_nodes_execution_id_records:
+            node_execution = (
+                db.session.query(
+                    WorkflowNodeExecution.id,
+                    WorkflowNodeExecution.tenant_id,
+                    WorkflowNodeExecution.app_id,
+                    WorkflowNodeExecution.title,
+                    WorkflowNodeExecution.node_type,
+                    WorkflowNodeExecution.status,
+                    WorkflowNodeExecution.inputs,
+                    WorkflowNodeExecution.outputs,
+                    WorkflowNodeExecution.created_at,
+                    WorkflowNodeExecution.elapsed_time,
+                    WorkflowNodeExecution.process_data,
+                    WorkflowNodeExecution.execution_metadata,
+                )
+                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
+                .first()
+            )
+
+            if not node_execution:
+                continue
+
             node_execution_id = node_execution.id
             tenant_id = node_execution.tenant_id
             app_id = node_execution.app_id