ソースを参照

refactor(api/core/app/apps/base_app_generator.py): improve input validation and sanitization in BaseAppGenerator (#5866)

-LAN- 9 ヶ月 前
コミット
66a62e6c13
2 ファイル変更52 行追加44 行削除
  1. 4 0
      api/core/app/app_config/entities.py
  2. 48 44
      api/core/app/apps/base_app_generator.py

+ 4 - 0
api/core/app/app_config/entities.py

@@ -114,6 +114,10 @@ class VariableEntity(BaseModel):
     default: Optional[str] = None
     hint: Optional[str] = None
 
+    @property
+    def name(self) -> str:
+        return self.variable
+
 
 class ExternalDataVariableEntity(BaseModel):
     """

+ 48 - 44
api/core/app/apps/base_app_generator.py

@@ -1,52 +1,56 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
 from core.app.app_config.entities import AppConfig, VariableEntity
 
 
 class BaseAppGenerator:
-    def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
-        if user_inputs is None:
-            user_inputs = {}
-
-        filtered_inputs = {}
-
+    def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
+        user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        for variable_config in variables:
-            variable = variable_config.variable
-
-            if (variable not in user_inputs
-                    or user_inputs[variable] is None
-                    or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
-                if variable_config.required:
-                    raise ValueError(f"{variable} is required in input form")
-                else:
-                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
-                    continue
-
-            value = user_inputs[variable]
-
-            if value is not None:
-                if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
-                    raise ValueError(f"{variable} in input form must be a string")
-                elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
-                    if '.' in value:
-                        value = float(value)
-                    else:
-                        value = int(value)
-
-            if variable_config.type == VariableEntity.Type.SELECT:
-                options = variable_config.options if variable_config.options is not None else []
-                if value not in options:
-                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
-                if variable_config.max_length is not None:
-                    max_length = variable_config.max_length
-                    if len(value) > max_length:
-                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
-
-            if value and isinstance(value, str):
-                filtered_inputs[variable] = value.replace('\x00', '')
-            else:
-                filtered_inputs[variable] = value if value is not None else None
-
+        filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
         return filtered_inputs
 
+    def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
+        user_input_value = inputs.get(var.name)
+        if var.required and not user_input_value:
+            raise ValueError(f'{var.name} is required in input form')
+        if not var.required and not user_input_value:
+            # TODO: should we return None here if the default value is None?
+            return var.default or ''
+        if (
+            var.type
+            in (
+                VariableEntity.Type.TEXT_INPUT,
+                VariableEntity.Type.SELECT,
+                VariableEntity.Type.PARAGRAPH,
+            )
+            and user_input_value
+            and not isinstance(user_input_value, str)
+        ):
+            raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
+        if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
+            # may raise ValueError if user_input_value is not a valid number
+            try:
+                if '.' in user_input_value:
+                    return float(user_input_value)
+                else:
+                    return int(user_input_value)
+            except ValueError:
+                raise ValueError(f"{var.name} in input form must be a valid number")
+        if var.type == VariableEntity.Type.SELECT:
+            options = var.options or []
+            if user_input_value not in options:
+                raise ValueError(f'{var.name} in input form must be one of the following: {options}')
+        elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
+            if var.max_length and user_input_value and len(user_input_value) > var.max_length:
+                raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
+
+        return user_input_value
+
+    def _sanitize_value(self, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.replace('\x00', '')
+        return value