Pārlūkot izejas kodu

chore(list_operator): refine exception handling for error specificity (#10206)

-LAN- 5 mēneši atpakaļ
vecāks
revīzija
1432c268a8

+ 16 - 0
api/core/workflow/nodes/list_operator/exc.py

@@ -0,0 +1,16 @@
+class ListOperatorError(ValueError):
+    """Base class for all ListOperator errors."""
+
+    pass
+
+
+class InvalidFilterValueError(ListOperatorError):
+    pass
+
+
+class InvalidKeyError(ListOperatorError):
+    pass
+
+
+class InvalidConditionError(ListOperatorError):
+    pass

+ 97 - 58
api/core/workflow/nodes/list_operator/node.py

@@ -1,5 +1,5 @@
 from collections.abc import Callable, Sequence
-from typing import Literal
+from typing import Literal, Union
 
 from core.file import File
 from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
 from models.workflow import WorkflowNodeExecutionStatus
 
 from .entities import ListOperatorNodeData
+from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
 
 
 class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
             )
-        if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
+        if not variable.value:
+            inputs = {"variable": []}
+            process_data = {"variable": []}
+            outputs = {"result": [], "first_record": None, "last_record": None}
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=inputs,
+                process_data=process_data,
+                outputs=outputs,
+            )
+        if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
             error_message = (
                 f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
                 "or ArrayStringSegment"
@@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
             )
 
         if isinstance(variable, ArrayFileSegment):
+            inputs = {"variable": [item.to_dict() for item in variable.value]}
             process_data["variable"] = [item.to_dict() for item in variable.value]
         else:
+            inputs = {"variable": variable.value}
             process_data["variable"] = variable.value
 
-        # Filter
-        if self.node_data.filter_by.enabled:
-            for condition in self.node_data.filter_by.conditions:
-                if isinstance(variable, ArrayStringSegment):
-                    if not isinstance(condition.value, str):
-                        raise ValueError(f"Invalid filter value: {condition.value}")
-                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
-                    filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
-                    result = list(filter(filter_func, variable.value))
-                    variable = variable.model_copy(update={"value": result})
-                elif isinstance(variable, ArrayNumberSegment):
-                    if not isinstance(condition.value, str):
-                        raise ValueError(f"Invalid filter value: {condition.value}")
-                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
-                    filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
-                    result = list(filter(filter_func, variable.value))
-                    variable = variable.model_copy(update={"value": result})
-                elif isinstance(variable, ArrayFileSegment):
-                    if isinstance(condition.value, str):
-                        value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
-                    else:
-                        value = condition.value
-                    filter_func = _get_file_filter_func(
-                        key=condition.key,
-                        condition=condition.comparison_operator,
-                        value=value,
-                    )
-                    result = list(filter(filter_func, variable.value))
-                    variable = variable.model_copy(update={"value": result})
-
-        # Order
-        if self.node_data.order_by.enabled:
+        try:
+            # Filter
+            if self.node_data.filter_by.enabled:
+                variable = self._apply_filter(variable)
+
+            # Order
+            if self.node_data.order_by.enabled:
+                variable = self._apply_order(variable)
+
+            # Slice
+            if self.node_data.limit.enabled:
+                variable = self._apply_slice(variable)
+
+            outputs = {
+                "result": variable.value,
+                "first_record": variable.value[0] if variable.value else None,
+                "last_record": variable.value[-1] if variable.value else None,
+            }
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=inputs,
+                process_data=process_data,
+                outputs=outputs,
+            )
+        except ListOperatorError as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e),
+                inputs=inputs,
+                process_data=process_data,
+                outputs=outputs,
+            )
+
+    def _apply_filter(
+        self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
+    ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
+        for condition in self.node_data.filter_by.conditions:
             if isinstance(variable, ArrayStringSegment):
-                result = _order_string(order=self.node_data.order_by.value, array=variable.value)
+                if not isinstance(condition.value, str):
+                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
+                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
+                filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
+                result = list(filter(filter_func, variable.value))
                 variable = variable.model_copy(update={"value": result})
             elif isinstance(variable, ArrayNumberSegment):
-                result = _order_number(order=self.node_data.order_by.value, array=variable.value)
+                if not isinstance(condition.value, str):
+                    raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
+                value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
+                filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
+                result = list(filter(filter_func, variable.value))
                 variable = variable.model_copy(update={"value": result})
             elif isinstance(variable, ArrayFileSegment):
-                result = _order_file(
-                    order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
+                if isinstance(condition.value, str):
+                    value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
+                else:
+                    value = condition.value
+                filter_func = _get_file_filter_func(
+                    key=condition.key,
+                    condition=condition.comparison_operator,
+                    value=value,
                 )
+                result = list(filter(filter_func, variable.value))
                 variable = variable.model_copy(update={"value": result})
+        return variable
 
-        # Slice
-        if self.node_data.limit.enabled:
-            result = variable.value[: self.node_data.limit.size]
+    def _apply_order(
+        self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
+    ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
+        if isinstance(variable, ArrayStringSegment):
+            result = _order_string(order=self.node_data.order_by.value, array=variable.value)
+            variable = variable.model_copy(update={"value": result})
+        elif isinstance(variable, ArrayNumberSegment):
+            result = _order_number(order=self.node_data.order_by.value, array=variable.value)
+            variable = variable.model_copy(update={"value": result})
+        elif isinstance(variable, ArrayFileSegment):
+            result = _order_file(
+                order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
+            )
             variable = variable.model_copy(update={"value": result})
+        return variable
 
-        outputs = {
-            "result": variable.value,
-            "first_record": variable.value[0] if variable.value else None,
-            "last_record": variable.value[-1] if variable.value else None,
-        }
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=inputs,
-            process_data=process_data,
-            outputs=outputs,
-        )
+    def _apply_slice(
+        self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
+    ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
+        result = variable.value[: self.node_data.limit.size]
+        return variable.model_copy(update={"value": result})
 
 
 def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
         case "size":
             return lambda x: x.size
         case _:
-            raise ValueError(f"Invalid key: {key}")
+            raise InvalidKeyError(f"Invalid key: {key}")
 
 
 def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
         case "url":
             return lambda x: x.remote_url or ""
         case _:
-            raise ValueError(f"Invalid key: {key}")
+            raise InvalidKeyError(f"Invalid key: {key}")
 
 
 def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
         case "not empty":
             return lambda x: x != ""
         case _:
-            raise ValueError(f"Invalid condition: {condition}")
+            raise InvalidConditionError(f"Invalid condition: {condition}")
 
 
 def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
         case "not in":
             return lambda x: not _in(value)(x)
         case _:
-            raise ValueError(f"Invalid condition: {condition}")
+            raise InvalidConditionError(f"Invalid condition: {condition}")
 
 
 def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
         case "≥":
             return _ge(value)
         case _:
-            raise ValueError(f"Invalid condition: {condition}")
+            raise InvalidConditionError(f"Invalid condition: {condition}")
 
 
 def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
         extract_func = _get_file_extract_number_func(key=key)
         return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
     else:
-        raise ValueError(f"Invalid key: {key}")
+        raise InvalidKeyError(f"Invalid key: {key}")
 
 
 def _contains(value: str):