Selaa lähdekoodia

fix: WorkflowNodeExecution.created_at may be earlier than WorkflowRun.created_at (#11070)

-LAN- 5 kuukautta sitten
vanhempi
commit
98d85e6b74
2 muutettua tiedostoa jossa 30 lisäystä ja 29 poistoa
  1. 25 24
      api/core/app/task_pipeline/workflow_cycle_manage.py
  2. 5 5
      api/models/workflow.py

+ 25 - 24
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -3,6 +3,7 @@ import time
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
+from uuid import uuid4
 
 from sqlalchemy.orm import Session
 
@@ -80,38 +81,38 @@ class WorkflowCycleManage:
 
             inputs[f"sys.{key.value}"] = value
 
-        inputs = WorkflowEntry.handle_special_values(inputs)
-
         triggered_from = (
             WorkflowRunTriggeredFrom.DEBUGGING
             if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
             else WorkflowRunTriggeredFrom.APP_RUN
         )
 
+        # handle special values
+        inputs = WorkflowEntry.handle_special_values(inputs)
+
         # init workflow run
-        workflow_run = WorkflowRun()
-        workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
-        if workflow_run_id:
-            workflow_run.id = workflow_run_id
-        workflow_run.tenant_id = self._workflow.tenant_id
-        workflow_run.app_id = self._workflow.app_id
-        workflow_run.sequence_number = new_sequence_number
-        workflow_run.workflow_id = self._workflow.id
-        workflow_run.type = self._workflow.type
-        workflow_run.triggered_from = triggered_from.value
-        workflow_run.version = self._workflow.version
-        workflow_run.graph = self._workflow.graph
-        workflow_run.inputs = json.dumps(inputs)
-        workflow_run.status = WorkflowRunStatus.RUNNING.value
-        workflow_run.created_by_role = (
-            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
-        )
-        workflow_run.created_by = self._user.id
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_run = WorkflowRun()
+            system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
+            workflow_run.id = system_id or str(uuid4())
+            workflow_run.tenant_id = self._workflow.tenant_id
+            workflow_run.app_id = self._workflow.app_id
+            workflow_run.sequence_number = new_sequence_number
+            workflow_run.workflow_id = self._workflow.id
+            workflow_run.type = self._workflow.type
+            workflow_run.triggered_from = triggered_from.value
+            workflow_run.version = self._workflow.version
+            workflow_run.graph = self._workflow.graph
+            workflow_run.inputs = json.dumps(inputs)
+            workflow_run.status = WorkflowRunStatus.RUNNING
+            workflow_run.created_by_role = (
+                CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
+            )
+            workflow_run.created_by = self._user.id
+            workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
 
-        db.session.add(workflow_run)
-        db.session.commit()
-        db.session.refresh(workflow_run)
-        db.session.close()
+            session.add(workflow_run)
+            session.commit()
 
         return workflow_run
 

+ 5 - 5
api/models/workflow.py

@@ -1,7 +1,7 @@
 import json
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional, Union
 
 import sqlalchemy as sa
@@ -314,7 +314,7 @@ class Workflow(db.Model):
         )
 
 
-class WorkflowRunStatus(Enum):
+class WorkflowRunStatus(StrEnum):
     """
     Workflow Run Status Enum
     """
@@ -393,13 +393,13 @@ class WorkflowRun(db.Model):
     version = db.Column(db.String(255), nullable=False)
     graph = db.Column(db.Text)
     inputs = db.Column(db.Text)
-    status = db.Column(db.String(255), nullable=False)
-    outputs: Mapped[str] = db.Column(db.Text)
+    status = db.Column(db.String(255), nullable=False)  # running, succeeded, failed, stopped
+    outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
     error = db.Column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
     total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
     total_steps = db.Column(db.Integer, server_default=db.text("0"))
-    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by_role = db.Column(db.String(255), nullable=False)  # account, end_user
     created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     finished_at = db.Column(db.DateTime)