Joe 8 ماه پیش
والد
کامیت
741c548f3c

+ 6 - 4
api/controllers/web/passport.py

@@ -9,21 +9,23 @@ from controllers.web.error import WebSSOAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
 from models.model import App, EndUser, Site
+from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 
 
 class PassportResource(Resource):
     """Base resource for passport."""
     def get(self):
-
         system_features = FeatureService.get_system_features()
-        if system_features.sso_enforced_for_web:
-            raise WebSSOAuthRequiredError()
-
         app_code = request.headers.get('X-App-Code')
         if app_code is None:
             raise Unauthorized('X-App-Code header is missing.')
 
+        if system_features.sso_enforced_for_web:
+            app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False)
+            if app_web_sso_enabled:
+                raise WebSSOAuthRequiredError()
+        
         # get site from db and check if it is normal
         site = db.session.query(Site).filter(
             Site.code == app_code,

+ 15 - 8
api/controllers/web/wraps.py

@@ -8,6 +8,7 @@ from controllers.web.error import WebSSOAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
 from models.model import App, EndUser, Site
+from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
 
 
@@ -26,7 +27,7 @@ def validate_jwt_token(view=None):
 
 def decode_jwt_token():
     system_features = FeatureService.get_system_features()
-
+    app_code = request.headers.get('X-App-Code')
     try:
         auth_header = request.headers.get('Authorization')
         if auth_header is None:
@@ -54,25 +55,31 @@ def decode_jwt_token():
         if not end_user:
             raise NotFound()
 
-        _validate_web_sso_token(decoded, system_features)
+        _validate_web_sso_token(decoded, system_features, app_code)
 
         return app_model, end_user
     except Unauthorized as e:
         if system_features.sso_enforced_for_web:
-            raise WebSSOAuthRequiredError()
+            app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False)
+            if app_web_sso_enabled:
+                raise WebSSOAuthRequiredError()
 
         raise Unauthorized(e.description)
 
 
-def _validate_web_sso_token(decoded, system_features):
+def _validate_web_sso_token(decoded, system_features, app_code):
+    app_web_sso_enabled = False
+    
     # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
     if system_features.sso_enforced_for_web:
-        source = decoded.get('token_source')
-        if not source or source != 'sso':
-            raise WebSSOAuthRequiredError()
+        app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False)
+        if app_web_sso_enabled:
+            source = decoded.get('token_source')
+            if not source or source != 'sso':
+                raise WebSSOAuthRequiredError()
 
     # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
-    if not system_features.sso_enforced_for_web:
+    if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
         source = decoded.get('token_source')
         if source and source == 'sso':
             raise Unauthorized('sso token expired.')

+ 4 - 0
api/services/enterprise/enterprise_service.py

@@ -6,3 +6,7 @@ class EnterpriseService:
     @classmethod
     def get_info(cls):
         return EnterpriseRequest.send_request('GET', '/info')
+
+    @classmethod
+    def get_app_web_sso_enabled(cls, app_code):
+        return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}')

+ 2 - 1
api/services/feature_service.py

@@ -41,7 +41,7 @@ class SystemFeatureModel(BaseModel):
     sso_enforced_for_signin_protocol: str = ''
     sso_enforced_for_web: bool = False
     sso_enforced_for_web_protocol: str = ''
-
+    enable_web_sso_switch_component: bool = False
 
 class FeatureService:
 
@@ -61,6 +61,7 @@ class FeatureService:
         system_features = SystemFeatureModel()
 
         if dify_config.ENTERPRISE_ENABLED:
+            system_features.enable_web_sso_switch_component = True
             cls._fulfill_params_from_enterprise(system_features)
 
         return system_features