Kaynağa Gözat

feat(app_dsl_service): enhance error handling and DSL version management (#10108)

-LAN- 5 ay önce
ebeveyn
işleme
e5397c5ec2

+ 1 - 1
api/models/model.py

@@ -396,7 +396,7 @@ class AppModelConfig(db.Model):
             "file_upload": self.file_upload_dict,
         }
 
-    def from_model_config_dict(self, model_config: dict):
+    def from_model_config_dict(self, model_config: Mapping[str, Any]):
         self.opening_statement = model_config.get("opening_statement")
         self.suggested_questions = (
             json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None

+ 3 - 0
api/services/app_dsl_service/__init__.py

@@ -0,0 +1,3 @@
+from .service import AppDslService
+
+__all__ = ["AppDslService"]

+ 34 - 0
api/services/app_dsl_service/exc.py

@@ -0,0 +1,34 @@
+class DSLVersionNotSupportedError(ValueError):
+    """Raised when the imported DSL version is not supported by the current Dify version."""
+
+
+class InvalidYAMLFormatError(ValueError):
+    """Raised when the provided YAML format is invalid."""
+
+
+class MissingAppDataError(ValueError):
+    """Raised when the app data is missing in the provided DSL."""
+
+
+class InvalidAppModeError(ValueError):
+    """Raised when the app mode is invalid."""
+
+
+class MissingWorkflowDataError(ValueError):
+    """Raised when the workflow data is missing in the provided DSL."""
+
+
+class MissingModelConfigError(ValueError):
+    """Raised when the model config data is missing in the provided DSL."""
+
+
+class FileSizeLimitExceededError(ValueError):
+    """Raised when the file size exceeds the allowed limit."""
+
+
+class EmptyContentError(ValueError):
+    """Raised when the content fetched from the URL is empty."""
+
+
+class ContentDecodingError(ValueError):
+    """Raised when there is an error decoding the content."""

+ 99 - 62
api/services/app_dsl_service.py → api/services/app_dsl_service/service.py

@@ -1,8 +1,11 @@
 import logging
+from collections.abc import Mapping
+from typing import Any
 
-import httpx
-import yaml  # type: ignore
+import yaml
+from packaging import version
 
+from core.helper import ssrf_proxy
 from events.app_event import app_model_config_was_updated, app_was_created
 from extensions.ext_database import db
 from factories import variable_factory
@@ -11,6 +14,18 @@ from models.model import App, AppMode, AppModelConfig
 from models.workflow import Workflow
 from services.workflow_service import WorkflowService
 
+from .exc import (
+    ContentDecodingError,
+    DSLVersionNotSupportedError,
+    EmptyContentError,
+    FileSizeLimitExceededError,
+    InvalidAppModeError,
+    InvalidYAMLFormatError,
+    MissingAppDataError,
+    MissingModelConfigError,
+    MissingWorkflowDataError,
+)
+
 logger = logging.getLogger(__name__)
 
 current_dsl_version = "0.1.2"
@@ -30,32 +45,21 @@ class AppDslService:
         :param args: request args
         :param account: Account instance
         """
-        try:
-            max_size = 10 * 1024 * 1024  # 10MB
-            timeout = httpx.Timeout(10.0)
-            with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response:
-                response.raise_for_status()
-                total_size = 0
-                content = b""
-                for chunk in response.iter_bytes():
-                    total_size += len(chunk)
-                    if total_size > max_size:
-                        raise ValueError("File size exceeds the limit of 10MB")
-                    content += chunk
-        except httpx.HTTPStatusError as http_err:
-            raise ValueError(f"HTTP error occurred: {http_err}")
-        except httpx.RequestError as req_err:
-            raise ValueError(f"Request error occurred: {req_err}")
-        except Exception as e:
-            raise ValueError(f"Failed to fetch DSL from URL: {e}")
+        max_size = 10 * 1024 * 1024  # 10MB
+        response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10))
+        response.raise_for_status()
+        content = response.content
+
+        if len(content) > max_size:
+            raise FileSizeLimitExceededError("File size exceeds the limit of 10MB")
 
         if not content:
-            raise ValueError("Empty content from url")
+            raise EmptyContentError("Empty content from url")
 
         try:
             data = content.decode("utf-8")
         except UnicodeDecodeError as e:
-            raise ValueError(f"Error decoding content: {e}")
+            raise ContentDecodingError(f"Error decoding content: {e}")
 
         return cls.import_and_create_new_app(tenant_id, data, args, account)
 
@@ -71,14 +75,14 @@ class AppDslService:
         try:
             import_data = yaml.safe_load(data)
         except yaml.YAMLError:
-            raise ValueError("Invalid YAML format in data argument.")
+            raise InvalidYAMLFormatError("Invalid YAML format in data argument.")
 
         # check or repair dsl version
-        import_data = cls._check_or_fix_dsl(import_data)
+        import_data = _check_or_fix_dsl(import_data)
 
         app_data = import_data.get("app")
         if not app_data:
-            raise ValueError("Missing app in data argument")
+            raise MissingAppDataError("Missing app in data argument")
 
         # get app basic info
         name = args.get("name") or app_data.get("name")
@@ -90,11 +94,18 @@ class AppDslService:
 
         # import dsl and create app
         app_mode = AppMode.value_of(app_data.get("mode"))
+
         if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+            workflow_data = import_data.get("workflow")
+            if not workflow_data or not isinstance(workflow_data, dict):
+                raise MissingWorkflowDataError(
+                    "Missing workflow in data argument when app mode is advanced-chat or workflow"
+                )
+
             app = cls._import_and_create_new_workflow_based_app(
                 tenant_id=tenant_id,
                 app_mode=app_mode,
-                workflow_data=import_data.get("workflow"),
+                workflow_data=workflow_data,
                 account=account,
                 name=name,
                 description=description,
@@ -104,10 +115,16 @@ class AppDslService:
                 use_icon_as_answer_icon=use_icon_as_answer_icon,
             )
         elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
+            model_config = import_data.get("model_config")
+            if not model_config or not isinstance(model_config, dict):
+                raise MissingModelConfigError(
+                    "Missing model_config in data argument when app mode is chat, agent-chat or completion"
+                )
+
             app = cls._import_and_create_new_model_config_based_app(
                 tenant_id=tenant_id,
                 app_mode=app_mode,
-                model_config_data=import_data.get("model_config"),
+                model_config_data=model_config,
                 account=account,
                 name=name,
                 description=description,
@@ -117,7 +134,7 @@ class AppDslService:
                 use_icon_as_answer_icon=use_icon_as_answer_icon,
             )
         else:
-            raise ValueError("Invalid app mode")
+            raise InvalidAppModeError("Invalid app mode")
 
         return app
 
@@ -132,26 +149,32 @@ class AppDslService:
         try:
             import_data = yaml.safe_load(data)
         except yaml.YAMLError:
-            raise ValueError("Invalid YAML format in data argument.")
+            raise InvalidYAMLFormatError("Invalid YAML format in data argument.")
 
         # check or repair dsl version
-        import_data = cls._check_or_fix_dsl(import_data)
+        import_data = _check_or_fix_dsl(import_data)
 
         app_data = import_data.get("app")
         if not app_data:
-            raise ValueError("Missing app in data argument")
+            raise MissingAppDataError("Missing app in data argument")
 
         # import dsl and overwrite app
         app_mode = AppMode.value_of(app_data.get("mode"))
         if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            raise ValueError("Only support import workflow in advanced-chat or workflow app.")
+            raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.")
 
         if app_data.get("mode") != app_model.mode:
             raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
 
+        workflow_data = import_data.get("workflow")
+        if not workflow_data or not isinstance(workflow_data, dict):
+            raise MissingWorkflowDataError(
+                "Missing workflow in data argument when app mode is advanced-chat or workflow"
+            )
+
         return cls._import_and_overwrite_workflow_based_app(
             app_model=app_model,
-            workflow_data=import_data.get("workflow"),
+            workflow_data=workflow_data,
             account=account,
         )
 
@@ -186,35 +209,12 @@ class AppDslService:
 
         return yaml.dump(export_data, allow_unicode=True)
 
-    @classmethod
-    def _check_or_fix_dsl(cls, import_data: dict) -> dict:
-        """
-        Check or fix dsl
-
-        :param import_data: import data
-        """
-        if not import_data.get("version"):
-            import_data["version"] = "0.1.0"
-
-        if not import_data.get("kind") or import_data.get("kind") != "app":
-            import_data["kind"] = "app"
-
-        if import_data.get("version") != current_dsl_version:
-            # Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
-            logger.warning(
-                f"DSL version {import_data.get('version')} is not compatible "
-                f"with current version {current_dsl_version}, related to "
-                f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
-            )
-
-        return import_data
-
     @classmethod
     def _import_and_create_new_workflow_based_app(
         cls,
         tenant_id: str,
         app_mode: AppMode,
-        workflow_data: dict,
+        workflow_data: Mapping[str, Any],
         account: Account,
         name: str,
         description: str,
@@ -238,7 +238,9 @@ class AppDslService:
         :param use_icon_as_answer_icon: use app icon as answer icon
         """
         if not workflow_data:
-            raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow")
+            raise MissingWorkflowDataError(
+                "Missing workflow in data argument when app mode is advanced-chat or workflow"
+            )
 
         app = cls._create_app(
             tenant_id=tenant_id,
@@ -277,7 +279,7 @@ class AppDslService:
 
     @classmethod
     def _import_and_overwrite_workflow_based_app(
-        cls, app_model: App, workflow_data: dict, account: Account
+        cls, app_model: App, workflow_data: Mapping[str, Any], account: Account
     ) -> Workflow:
         """
         Import app dsl and overwrite workflow based app
@@ -287,7 +289,9 @@ class AppDslService:
         :param account: Account instance
         """
         if not workflow_data:
-            raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow")
+            raise MissingWorkflowDataError(
+                "Missing workflow in data argument when app mode is advanced-chat or workflow"
+            )
 
         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
@@ -323,7 +327,7 @@ class AppDslService:
         cls,
         tenant_id: str,
         app_mode: AppMode,
-        model_config_data: dict,
+        model_config_data: Mapping[str, Any],
         account: Account,
         name: str,
         description: str,
@@ -345,7 +349,9 @@ class AppDslService:
         :param icon_background: app icon background
         """
         if not model_config_data:
-            raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion")
+            raise MissingModelConfigError(
+                "Missing model_config in data argument when app mode is chat, agent-chat or completion"
+            )
 
         app = cls._create_app(
             tenant_id=tenant_id,
@@ -448,3 +454,34 @@ class AppDslService:
             raise ValueError("Missing app configuration, please check.")
 
         export_data["model_config"] = app_model_config.to_dict()
+
+
+def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]:
+    """
+    Check or fix dsl
+
+    :param import_data: import data
+    :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version
+    """
+    if not import_data.get("version"):
+        import_data["version"] = "0.1.0"
+
+    if not import_data.get("kind") or import_data.get("kind") != "app":
+        import_data["kind"] = "app"
+
+    imported_version = import_data.get("version")
+    if imported_version != current_dsl_version:
+        if imported_version and version.parse(imported_version) > version.parse(current_dsl_version):
+            raise DSLVersionNotSupportedError(
+                f"The imported DSL version {imported_version} is newer than "
+                f"the current supported version {current_dsl_version}. "
+                f"Please upgrade your Dify instance to import this configuration."
+            )
+        else:
+            logger.warning(
+                f"DSL version {imported_version} is older than "
+                f"the current version {current_dsl_version}. "
+                f"This may cause compatibility issues."
+            )
+
+    return import_data

+ 41 - 0
api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py

@@ -0,0 +1,41 @@
+import pytest
+from packaging import version
+
+from services.app_dsl_service import AppDslService
+from services.app_dsl_service.exc import DSLVersionNotSupportedError
+from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version
+
+
+class TestAppDSLService:
+    def test_check_or_fix_dsl_missing_version(self):
+        import_data = {}
+        result = _check_or_fix_dsl(import_data)
+        assert result["version"] == "0.1.0"
+        assert result["kind"] == "app"
+
+    def test_check_or_fix_dsl_missing_kind(self):
+        import_data = {"version": "0.1.0"}
+        result = _check_or_fix_dsl(import_data)
+        assert result["kind"] == "app"
+
+    def test_check_or_fix_dsl_older_version(self):
+        import_data = {"version": "0.0.9", "kind": "app"}
+        result = _check_or_fix_dsl(import_data)
+        assert result["version"] == "0.0.9"
+
+    def test_check_or_fix_dsl_current_version(self):
+        import_data = {"version": current_dsl_version, "kind": "app"}
+        result = _check_or_fix_dsl(import_data)
+        assert result["version"] == current_dsl_version
+
+    def test_check_or_fix_dsl_newer_version(self):
+        current_version = version.parse(current_dsl_version)
+        newer_version = f"{current_version.major}.{current_version.minor + 1}.0"
+        import_data = {"version": newer_version, "kind": "app"}
+        with pytest.raises(DSLVersionNotSupportedError):
+            _check_or_fix_dsl(import_data)
+
+    def test_check_or_fix_dsl_invalid_kind(self):
+        import_data = {"version": current_dsl_version, "kind": "invalid"}
+        result = _check_or_fix_dsl(import_data)
+        assert result["kind"] == "app"