Kaynağa Gözat

chore(api): improve type hints for BaseNode and its subclasses (#15826)

QuantumGhost 1 ay önce
ebeveyn
işleme
23ed3a520b

+ 2 - 2
api/core/workflow/nodes/base/node.py

@@ -22,7 +22,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
 
 
 class BaseNode(Generic[GenericNodeData]):
-    _node_data_cls: type[BaseNodeData]
+    _node_data_cls: type[GenericNodeData]
     _node_type: NodeType
 
     def __init__(
@@ -57,7 +57,7 @@ class BaseNode(Generic[GenericNodeData]):
         self.node_id = node_id
 
         node_data = self._node_data_cls.model_validate(config.get("data", {}))
-        self.node_data = cast(GenericNodeData, node_data)
+        self.node_data = node_data
 
     @abstractmethod
     def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

+ 1 - 1
api/core/workflow/nodes/iteration/iteration_start_node.py

@@ -5,7 +5,7 @@ from core.workflow.nodes.iteration.entities import IterationStartNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
 
-class IterationStartNode(BaseNode):
+class IterationStartNode(BaseNode[IterationStartNodeData]):
     """
     Iteration Start Node.
     """

+ 1 - 1
api/core/workflow/nodes/loop/loop_start_node.py

@@ -5,7 +5,7 @@ from core.workflow.nodes.loop.entities import LoopStartNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
 
-class LoopStartNode(BaseNode):
+class LoopStartNode(BaseNode[LoopStartNodeData]):
     """
     Loop Start Node.
     """

+ 2 - 2
api/core/workflow/nodes/variable_assigner/v1/node.py

@@ -1,6 +1,6 @@
 from core.variables import SegmentType, Variable
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.base import BaseNode, BaseNodeData
+from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
 from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
@@ -11,7 +11,7 @@ from .node_data import VariableAssignerData, WriteMode
 
 
 class VariableAssignerNode(BaseNode[VariableAssignerData]):
-    _node_data_cls: type[BaseNodeData] = VariableAssignerData
+    _node_data_cls = VariableAssignerData
     _node_type = NodeType.VARIABLE_ASSIGNER
 
     def _run(self) -> NodeRunResult: