Forráskód Böngészése

fix(workflow): improve database session handling and variable management (#9581)

-LAN- 6 hónapja
szülő
commit
c063617553

+ 24 - 22
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Any, Optional, Union, cast
 
+from sqlalchemy.orm import Session
+
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
@@ -232,30 +234,30 @@ class WorkflowCycleManage:
         self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
     ) -> WorkflowNodeExecution:
         # init workflow node execution
-        workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.tenant_id = workflow_run.tenant_id
-        workflow_node_execution.app_id = workflow_run.app_id
-        workflow_node_execution.workflow_id = workflow_run.workflow_id
-        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
-        workflow_node_execution.workflow_run_id = workflow_run.id
-        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
-        workflow_node_execution.index = event.node_run_index
-        workflow_node_execution.node_execution_id = event.node_execution_id
-        workflow_node_execution.node_id = event.node_id
-        workflow_node_execution.node_type = event.node_type.value
-        workflow_node_execution.title = event.node_data.title
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
-        workflow_node_execution.created_by_role = workflow_run.created_by_role
-        workflow_node_execution.created_by = workflow_run.created_by
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
-
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
 
-        self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_node_execution = WorkflowNodeExecution()
+            workflow_node_execution.tenant_id = workflow_run.tenant_id
+            workflow_node_execution.app_id = workflow_run.app_id
+            workflow_node_execution.workflow_id = workflow_run.workflow_id
+            workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+            workflow_node_execution.workflow_run_id = workflow_run.id
+            workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+            workflow_node_execution.index = event.node_run_index
+            workflow_node_execution.node_execution_id = event.node_execution_id
+            workflow_node_execution.node_id = event.node_id
+            workflow_node_execution.node_type = event.node_type.value
+            workflow_node_execution.title = event.node_data.title
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+            workflow_node_execution.created_by_role = workflow_run.created_by_role
+            workflow_node_execution.created_by = workflow_run.created_by
+            workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+            session.add(workflow_node_execution)
+            session.commit()
+            session.refresh(workflow_node_execution)
 
+        self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
         return workflow_node_execution
 
     def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:

+ 0 - 3
api/core/workflow/entities/variable_pool.py

@@ -96,9 +96,6 @@ class VariablePool(BaseModel):
         if len(selector) < 2:
             raise ValueError("Invalid selector")
 
-        if value is None:
-            return
-
         if isinstance(value, Segment):
             v = value
         else:

+ 4 - 4
api/core/workflow/nodes/llm/node.py

@@ -320,11 +320,11 @@ class LLMNode(BaseNode[LLMNodeData]):
             variable_selectors = variable_template_parser.extract_variable_selectors()
 
         for variable_selector in variable_selectors:
-            variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
+            variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable_value is None:
                 raise ValueError(f"Variable {variable_selector.variable} not found")
 
-            inputs[variable_selector.variable] = variable_value
+            inputs[variable_selector.variable] = variable_value.to_object()
 
         memory = node_data.memory
         if memory and memory.query_prompt_template:
@@ -332,11 +332,11 @@ class LLMNode(BaseNode[LLMNodeData]):
                 template=memory.query_prompt_template
             ).extract_variable_selectors()
             for variable_selector in query_variable_selectors:
-                variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
+                variable_value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
                 if variable_value is None:
                     raise ValueError(f"Variable {variable_selector.variable} not found")
 
-                inputs[variable_selector.variable] = variable_value
+                inputs[variable_selector.variable] = variable_value.to_object()
 
         return inputs