Parcourir la source

Remove useless code (#4416)

Garfield Dai il y a 11 mois
Parent
commit
dd94931116

+ 0 - 3
api/controllers/console/__init__.py

@@ -37,9 +37,6 @@ from .billing import billing
 # Import datasets controllers
 from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
 
-# Import enterprise controllers
-from .enterprise import enterprise_sso
-
 # Import explore controllers
 from .explore import (
     audio,

+ 0 - 0
api/controllers/console/enterprise/__init__.py


+ 0 - 59
api/controllers/console/enterprise/enterprise_sso.py

@@ -1,59 +0,0 @@
-from flask import current_app, redirect
-from flask_restful import Resource, reqparse
-
-from controllers.console import api
-from controllers.console.setup import setup_required
-from services.enterprise.enterprise_sso_service import EnterpriseSSOService
-
-
-class EnterpriseSSOSamlLogin(Resource):
-
-    @setup_required
-    def get(self):
-        return EnterpriseSSOService.get_sso_saml_login()
-
-
-class EnterpriseSSOSamlAcs(Resource):
-
-    @setup_required
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('SAMLResponse', type=str, required=True, location='form')
-        args = parser.parse_args()
-        saml_response = args['SAMLResponse']
-
-        try:
-            token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
-        except Exception as e:
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
-
-
-class EnterpriseSSOOidcLogin(Resource):
-
-    @setup_required
-    def get(self):
-        return EnterpriseSSOService.get_sso_oidc_login()
-
-
-class EnterpriseSSOOidcCallback(Resource):
-
-    @setup_required
-    def get(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument('state', type=str, required=True, location='args')
-        parser.add_argument('code', type=str, required=True, location='args')
-        parser.add_argument('oidc-state', type=str, required=True, location='cookies')
-        args = parser.parse_args()
-
-        try:
-            token = EnterpriseSSOService.get_sso_oidc_callback(args)
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
-        except Exception as e:
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
-
-
-api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
-api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
-api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
-api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')

+ 3 - 4
api/controllers/console/feature.py

@@ -1,7 +1,6 @@
 from flask_login import current_user
 from flask_restful import Resource
 
-from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
 from services.feature_service import FeatureService
 
 from . import api
@@ -15,10 +14,10 @@ class FeatureApi(Resource):
         return FeatureService.get_features(current_user.current_tenant_id).dict()
 
 
-class EnterpriseFeatureApi(Resource):
+class SystemFeatureApi(Resource):
     def get(self):
-        return EnterpriseFeatureService.get_enterprise_features().dict()
+        return FeatureService.get_system_features().dict()
 
 
 api.add_resource(FeatureApi, '/features')
-api.add_resource(EnterpriseFeatureApi, '/enterprise-features')
+api.add_resource(SystemFeatureApi, '/system-features')

+ 1 - 1
api/controllers/web/__init__.py

@@ -6,4 +6,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
 api = ExternalApi(bp)
 
 
-from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow
+from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow

+ 1 - 5
api/controllers/web/app.py

@@ -1,14 +1,10 @@
-import json
-
 from flask import current_app
 from flask_restful import fields, marshal_with
 
 from controllers.web import api
 from controllers.web.error import AppUnavailableError
 from controllers.web.wraps import WebApiResource
-from extensions.ext_database import db
-from models.model import App, AppMode, AppModelConfig
-from models.tools import ApiToolProvider
+from models.model import App, AppMode
 from services.app_service import AppService
 
 

+ 6 - 0
api/controllers/web/error.py

@@ -115,3 +115,9 @@ class UnsupportedFileTypeError(BaseHTTPException):
     error_code = 'unsupported_file_type'
     description = "File type not allowed."
     code = 415
+
+
+class WebSSOAuthRequiredError(BaseHTTPException):
+    error_code = 'web_sso_auth_required'
+    description = "Web SSO authentication required."
+    code = 401

+ 12 - 0
api/controllers/web/feature.py

@@ -0,0 +1,12 @@
+from flask_restful import Resource
+
+from controllers.web import api
+from services.feature_service import FeatureService
+
+
+class SystemFeatureApi(Resource):
+    def get(self):
+        return FeatureService.get_system_features().dict()
+
+
+api.add_resource(SystemFeatureApi, '/system-features')

+ 11 - 1
api/controllers/web/passport.py

@@ -5,14 +5,21 @@ from flask_restful import Resource
 from werkzeug.exceptions import NotFound, Unauthorized
 
 from controllers.web import api
+from controllers.web.error import WebSSOAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
 from models.model import App, EndUser, Site
+from services.feature_service import FeatureService
 
 
 class PassportResource(Resource):
     """Base resource for passport."""
     def get(self):
+
+        system_features = FeatureService.get_system_features()
+        if system_features.sso_enforced_for_web:
+            raise WebSSOAuthRequiredError()
+
         app_code = request.headers.get('X-App-Code')
         if app_code is None:
             raise Unauthorized('X-App-Code header is missing.')
@@ -28,7 +35,7 @@ class PassportResource(Resource):
         app_model = db.session.query(App).filter(App.id == site.app_id).first()
         if not app_model or app_model.status != 'normal' or not app_model.enable_site:
             raise NotFound()
-        
+
         end_user = EndUser(
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
@@ -36,6 +43,7 @@ class PassportResource(Resource):
             is_anonymous=True,
             session_id=generate_session_id(),
         )
+
         db.session.add(end_user)
         db.session.commit()
 
@@ -53,8 +61,10 @@ class PassportResource(Resource):
             'access_token': tk,
         }
 
+
 api.add_resource(PassportResource, '/passport')
 
+
 def generate_session_id():
     """
     Generate a unique session ID.

+ 56 - 28
api/controllers/web/wraps.py

@@ -2,11 +2,13 @@ from functools import wraps
 
 from flask import request
 from flask_restful import Resource
-from werkzeug.exceptions import NotFound, Unauthorized
+from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
 
+from controllers.web.error import WebSSOAuthRequiredError
 from extensions.ext_database import db
 from libs.passport import PassportService
 from models.model import App, EndUser, Site
+from services.feature_service import FeatureService
 
 
 def validate_jwt_token(view=None):
@@ -21,34 +23,60 @@ def validate_jwt_token(view=None):
         return decorator(view)
     return decorator
 
+
 def decode_jwt_token():
-    auth_header = request.headers.get('Authorization')
-    if auth_header is None:
-        raise Unauthorized('Authorization header is missing.')
-
-    if ' ' not in auth_header:
-        raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-    
-    auth_scheme, tk = auth_header.split(None, 1)
-    auth_scheme = auth_scheme.lower()
-
-    if auth_scheme != 'bearer':
-        raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-    decoded = PassportService().verify(tk)
-    app_code = decoded.get('app_code')
-    app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
-    site = db.session.query(Site).filter(Site.code == app_code).first()
-    if not app_model:
-        raise NotFound()
-    if not app_code or not site:
-        raise Unauthorized('Site URL is no longer valid.')
-    if app_model.enable_site is False:
-        raise Unauthorized('Site is disabled.')
-    end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
-    if not end_user:
-        raise NotFound()
-
-    return app_model, end_user
+    system_features = FeatureService.get_system_features()
+
+    try:
+        auth_header = request.headers.get('Authorization')
+        if auth_header is None:
+            raise Unauthorized('Authorization header is missing.')
+
+        if ' ' not in auth_header:
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+
+        auth_scheme, tk = auth_header.split(None, 1)
+        auth_scheme = auth_scheme.lower()
+
+        if auth_scheme != 'bearer':
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        decoded = PassportService().verify(tk)
+        app_code = decoded.get('app_code')
+        app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
+        site = db.session.query(Site).filter(Site.code == app_code).first()
+        if not app_model:
+            raise NotFound()
+        if not app_code or not site:
+            raise BadRequest('Site URL is no longer valid.')
+        if app_model.enable_site is False:
+            raise BadRequest('Site is disabled.')
+        end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
+        if not end_user:
+            raise NotFound()
+
+        _validate_web_sso_token(decoded, system_features)
+
+        return app_model, end_user
+    except Unauthorized as e:
+        if system_features.sso_enforced_for_web:
+            raise WebSSOAuthRequiredError()
+
+        raise Unauthorized(e.description)
+
+
+def _validate_web_sso_token(decoded, system_features):
+    # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
+    if system_features.sso_enforced_for_web:
+        source = decoded.get('token_source')
+        if not source or source != 'sso':
+            raise WebSSOAuthRequiredError()
+
+    # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
+    if not system_features.sso_enforced_for_web:
+        source = decoded.get('token_source')
+        if source and source == 'sso':
+            raise Unauthorized('sso token expired.')
+
 
 class WebApiResource(Resource):
     method_decorators = [validate_jwt_token]

+ 0 - 28
api/services/enterprise/enterprise_feature_service.py

@@ -1,28 +0,0 @@
-from flask import current_app
-from pydantic import BaseModel
-
-from services.enterprise.enterprise_service import EnterpriseService
-
-
-class EnterpriseFeatureModel(BaseModel):
-    sso_enforced_for_signin: bool = False
-    sso_enforced_for_signin_protocol: str = ''
-
-
-class EnterpriseFeatureService:
-
-    @classmethod
-    def get_enterprise_features(cls) -> EnterpriseFeatureModel:
-        features = EnterpriseFeatureModel()
-
-        if current_app.config['ENTERPRISE_ENABLED']:
-            cls._fulfill_params_from_enterprise(features)
-
-        return features
-
-    @classmethod
-    def _fulfill_params_from_enterprise(cls, features):
-        enterprise_info = EnterpriseService.get_info()
-
-        features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
-        features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']

+ 0 - 60
api/services/enterprise/enterprise_sso_service.py

@@ -1,60 +0,0 @@
-import logging
-
-from models.account import Account, AccountStatus
-from services.account_service import AccountService, TenantService
-from services.enterprise.base import EnterpriseRequest
-
-logger = logging.getLogger(__name__)
-
-
-class EnterpriseSSOService:
-
-    @classmethod
-    def get_sso_saml_login(cls) -> str:
-        return EnterpriseRequest.send_request('GET', '/sso/saml/login')
-
-    @classmethod
-    def post_sso_saml_acs(cls, saml_response: str) -> str:
-        response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
-        if 'email' not in response or response['email'] is None:
-            logger.exception(response)
-            raise Exception('Saml response is invalid')
-
-        return cls.login_with_email(response.get('email'))
-
-    @classmethod
-    def get_sso_oidc_login(cls):
-        return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
-
-    @classmethod
-    def get_sso_oidc_callback(cls, args: dict):
-        state_from_query = args['state']
-        code_from_query = args['code']
-        state_from_cookies = args['oidc-state']
-
-        if state_from_cookies != state_from_query:
-            raise Exception('invalid state or code')
-
-        response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
-        if 'email' not in response or response['email'] is None:
-            logger.exception(response)
-            raise Exception('OIDC response is invalid')
-
-        return cls.login_with_email(response.get('email'))
-
-    @classmethod
-    def login_with_email(cls, email: str) -> str:
-        account = Account.query.filter_by(email=email).first()
-        if account is None:
-            raise Exception('account not found, please contact system admin to invite you to join in a workspace')
-
-        if account.status == AccountStatus.BANNED:
-            raise Exception('account is banned, please contact system admin')
-
-        tenants = TenantService.get_join_tenants(account)
-        if len(tenants) == 0:
-            raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
-
-        token = AccountService.get_account_jwt_token(account)
-
-        return token

+ 25 - 0
api/services/feature_service.py

@@ -2,6 +2,7 @@ from flask import current_app
 from pydantic import BaseModel
 
 from services.billing_service import BillingService
+from services.enterprise.enterprise_service import EnterpriseService
 
 
 class SubscriptionModel(BaseModel):
@@ -30,6 +31,13 @@ class FeatureModel(BaseModel):
     can_replace_logo: bool = False
 
 
+class SystemFeatureModel(BaseModel):
+    sso_enforced_for_signin: bool = False
+    sso_enforced_for_signin_protocol: str = ''
+    sso_enforced_for_web: bool = False
+    sso_enforced_for_web_protocol: str = ''
+
+
 class FeatureService:
 
     @classmethod
@@ -43,6 +51,15 @@ class FeatureService:
 
         return features
 
+    @classmethod
+    def get_system_features(cls) -> SystemFeatureModel:
+        system_features = SystemFeatureModel()
+
+        if current_app.config['ENTERPRISE_ENABLED']:
+            cls._fulfill_params_from_enterprise(system_features)
+
+        return system_features
+
     @classmethod
     def _fulfill_params_from_env(cls, features: FeatureModel):
         features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
@@ -73,3 +90,11 @@ class FeatureService:
         features.docs_processing = billing_info['docs_processing']
         features.can_replace_logo = billing_info['can_replace_logo']
 
+    @classmethod
+    def _fulfill_params_from_enterprise(cls, features):
+        enterprise_info = EnterpriseService.get_info()
+
+        features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
+        features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
+        features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
+        features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']

+ 0 - 1
web/app/(shareLayout)/chat/[token]/page.tsx

@@ -1,5 +1,4 @@
 'use client'
-
 import type { FC } from 'react'
 import React from 'react'
 

+ 78 - 3
web/app/(shareLayout)/chatbot/[token]/page.tsx

@@ -1,12 +1,87 @@
+'use client'
 import type { FC } from 'react'
-import React from 'react'
-
+import React, { useEffect } from 'react'
+import cn from 'classnames'
 import type { IMainProps } from '@/app/components/share/chat'
 import Main from '@/app/components/share/chatbot'
+import Loading from '@/app/components/base/loading'
+import { fetchSystemFeatures } from '@/service/share'
+import LogoSite from '@/app/components/base/logo/logo-site'
 
 const Chatbot: FC<IMainProps> = () => {
+  const [isSSOEnforced, setIsSSOEnforced] = React.useState(true)
+  const [loading, setLoading] = React.useState(true)
+
+  useEffect(() => {
+    fetchSystemFeatures().then((res) => {
+      setIsSSOEnforced(res.sso_enforced_for_web)
+      setLoading(false)
+    })
+  }, [])
+
   return (
-    <Main />
+    <>
+      {
+        loading
+          ? (
+            <div className="flex items-center justify-center h-full" >
+              <div className={
+                cn(
+                  'flex flex-col items-center w-full grow items-center justify-center',
+                  'px-6',
+                  'md:px-[108px]',
+                )
+              }>
+                <Loading type='area' />
+              </div>
+            </div >
+          )
+          : (
+            <>
+              {isSSOEnforced
+                ? (
+                  <div className={cn(
+                    'flex w-full min-h-screen',
+                    'sm:p-4 lg:p-8',
+                    'gap-x-20',
+                    'justify-center lg:justify-start',
+                  )}>
+                    <div className={
+                      cn(
+                        'flex w-full flex-col bg-white shadow rounded-2xl shrink-0',
+                        'space-between',
+                      )
+                    }>
+                      <div className='flex items-center justify-between p-6 w-full'>
+                        <LogoSite />
+                      </div>
+
+                      <div className={
+                        cn(
+                          'flex flex-col items-center w-full grow items-center justify-center',
+                          'px-6',
+                          'md:px-[108px]',
+                        )
+                      }>
+                        <div className='flex flex-col md:w-[400px]'>
+                          <div className="w-full mx-auto">
+                            <h2 className="text-[16px] font-bold text-gray-900">
+                          Warning: Chatbot is not available
+                            </h2>
+                            <p className="text-[16px] text-gray-600 mt-2">
+                          Because SSO is enforced. Please contact your administrator.
+                            </p>
+                          </div>
+                        </div>
+                      </div>
+                    </div>
+                  </div>
+                )
+                : <Main />
+              }
+            </>
+          )}
+    </>
   )
 }
 

+ 147 - 0
web/app/(shareLayout)/webapp-signin/page.tsx

@@ -0,0 +1,147 @@
+'use client'
+import cn from 'classnames'
+import { useRouter, useSearchParams } from 'next/navigation'
+import type { FC } from 'react'
+import React, { useEffect, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Toast from '@/app/components/base/toast'
+import Button from '@/app/components/base/button'
+import { fetchSystemFeatures, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
+import LogoSite from '@/app/components/base/logo/logo-site'
+import { setAccessToken } from '@/app/components/share/utils'
+
+const WebSSOForm: FC = () => {
+  const searchParams = useSearchParams()
+
+  const redirectUrl = searchParams.get('redirect_url')
+  const tokenFromUrl = searchParams.get('web_sso_token')
+  const message = searchParams.get('message')
+
+  const router = useRouter()
+  const { t } = useTranslation()
+
+  const [isLoading, setIsLoading] = useState(false)
+  const [protocal, setProtocal] = useState('')
+
+  useEffect(() => {
+    const fetchFeaturesAndSetToken = async () => {
+      await fetchSystemFeatures().then((res) => {
+        setProtocal(res.sso_enforced_for_web_protocol)
+      })
+
+      // Callback from SSO, process token and redirect
+      if (tokenFromUrl && redirectUrl) {
+        const appCode = redirectUrl.split('/').pop()
+        if (!appCode) {
+          Toast.notify({
+            type: 'error',
+            message: 'redirect url is invalid. App code is not found.',
+          })
+          return
+        }
+
+        await setAccessToken(appCode, tokenFromUrl)
+        router.push(redirectUrl)
+      }
+    }
+
+    fetchFeaturesAndSetToken()
+
+    if (message) {
+      Toast.notify({
+        type: 'error',
+        message,
+      })
+    }
+  }, [])
+
+  const handleSSOLogin = () => {
+    setIsLoading(true)
+
+    if (!redirectUrl) {
+      Toast.notify({
+        type: 'error',
+        message: 'redirect url is not found.',
+      })
+      setIsLoading(false)
+      return
+    }
+
+    const appCode = redirectUrl.split('/').pop()
+    if (!appCode) {
+      Toast.notify({
+        type: 'error',
+        message: 'redirect url is invalid. App code is not found.',
+      })
+      return
+    }
+
+    if (protocal === 'saml') {
+      fetchWebSAMLSSOUrl(appCode, redirectUrl).then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else if (protocal === 'oidc') {
+      fetchWebOIDCSSOUrl(appCode, redirectUrl).then((res) => {
+        router.push(res.url)
+      }).finally(() => {
+        setIsLoading(false)
+      })
+    }
+    else {
+      Toast.notify({
+        type: 'error',
+        message: 'sso protocal is not supported.',
+      })
+      setIsLoading(false)
+    }
+  }
+
+  return (
+    <div className={cn(
+      'flex w-full min-h-screen',
+      'sm:p-4 lg:p-8',
+      'gap-x-20',
+      'justify-center lg:justify-start',
+    )}>
+      <div className={
+        cn(
+          'flex w-full flex-col bg-white shadow rounded-2xl shrink-0',
+          'space-between',
+        )
+      }>
+        <div className='flex items-center justify-between p-6 w-full'>
+          <LogoSite />
+        </div>
+
+        <div className={
+          cn(
+            'flex flex-col items-center w-full grow items-center justify-center',
+            'px-6',
+            'md:px-[108px]',
+          )
+        }>
+          <div className='flex flex-col md:w-[400px]'>
+            <div className="w-full mx-auto">
+              <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2>
+            </div>
+            <div className="w-full mx-auto mt-10">
+              <Button
+                tabIndex={0}
+                type='primary'
+                onClick={() => { handleSSOLogin() }}
+                disabled={isLoading}
+                className="w-full !fone-medium !text-sm"
+              >{t('login.sso')}
+              </Button>
+            </div>
+          </div>
+        </div>
+      </div>
+    </div>
+  )
+}
+
+export default React.memo(WebSSOForm)

+ 36 - 0
web/app/components/share/utils.ts

@@ -1,4 +1,6 @@
+import { CONVERSATION_ID_INFO } from '../base/chat/constants'
 import { fetchAccessToken } from '@/service/share'
+
 export const checkOrSetAccessToken = async () => {
   const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
   const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
@@ -15,3 +17,37 @@ export const checkOrSetAccessToken = async () => {
     localStorage.setItem('token', JSON.stringify(accessTokenJson))
   }
 }
+
+export const setAccessToken = async (sharedToken: string, token: string) => {
+  const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+  let accessTokenJson = { [sharedToken]: '' }
+  try {
+    accessTokenJson = JSON.parse(accessToken)
+  }
+  catch (e) {
+
+  }
+
+  localStorage.removeItem(CONVERSATION_ID_INFO)
+
+  accessTokenJson[sharedToken] = token
+  localStorage.setItem('token', JSON.stringify(accessTokenJson))
+}
+
+export const removeAccessToken = () => {
+  const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
+
+  const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
+  let accessTokenJson = { [sharedToken]: '' }
+  try {
+    accessTokenJson = JSON.parse(accessToken)
+  }
+  catch (e) {
+
+  }
+
+  localStorage.removeItem(CONVERSATION_ID_INFO)
+
+  delete accessTokenJson[sharedToken]
+  localStorage.setItem('token', JSON.stringify(accessTokenJson))
+}

+ 11 - 10
web/app/signin/page.tsx

@@ -6,19 +6,20 @@ import Loading from '../components/base/loading'
 import Forms from './forms'
 import Header from './_header'
 import style from './page.module.css'
-import EnterpriseSSOForm from './enterpriseSSOForm'
+import UserSSOForm from './userSSOForm'
 import { IS_CE_EDITION } from '@/config'
-import { getEnterpriseFeatures } from '@/service/enterprise'
-import type { EnterpriseFeatures } from '@/types/enterprise'
-import { defaultEnterpriseFeatures } from '@/types/enterprise'
+
+import type { SystemFeatures } from '@/types/feature'
+import { defaultSystemFeatures } from '@/types/feature'
+import { getSystemFeatures } from '@/service/common'
 
 const SignIn = () => {
   const [loading, setLoading] = useState<boolean>(true)
-  const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures)
+  const [systemFeatures, setSystemFeatures] = useState<SystemFeatures>(defaultSystemFeatures)
 
   useEffect(() => {
-    getEnterpriseFeatures().then((res) => {
-      setEnterpriseFeatures(res)
+    getSystemFeatures().then((res) => {
+      setSystemFeatures(res)
     }).finally(() => {
       setLoading(false)
     })
@@ -70,7 +71,7 @@ gtag('config', 'AW-11217955271"');
             </div>
           )}
 
-          {!loading && !enterpriseFeatures.sso_enforced_for_signin && (
+          {!loading && !systemFeatures.sso_enforced_for_signin && (
             <>
               <Forms />
               <div className='px-8 py-6 text-sm font-normal text-gray-500'>
@@ -79,8 +80,8 @@ gtag('config', 'AW-11217955271"');
             </>
           )}
 
-          {!loading && enterpriseFeatures.sso_enforced_for_signin && (
-            <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} />
+          {!loading && systemFeatures.sso_enforced_for_signin && (
+            <UserSSOForm protocol={systemFeatures.sso_enforced_for_signin_protocol} />
           )}
         </div>
 

+ 7 - 7
web/app/signin/enterpriseSSOForm.tsx → web/app/signin/userSSOForm.tsx

@@ -5,14 +5,14 @@ import type { FC } from 'react'
 import { useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import Toast from '@/app/components/base/toast'
-import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise'
+import { getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso'
 import Button from '@/app/components/base/button'
 
-type EnterpriseSSOFormProps = {
+type UserSSOFormProps = {
   protocol: string
 }
 
-const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
+const UserSSOForm: FC<UserSSOFormProps> = ({
   protocol,
 }) => {
   const searchParams = useSearchParams()
@@ -41,15 +41,15 @@ const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
   const handleSSOLogin = () => {
     setIsLoading(true)
     if (protocol === 'saml') {
-      getSAMLSSOUrl().then((res) => {
+      getUserSAMLSSOUrl().then((res) => {
         router.push(res.url)
       }).finally(() => {
         setIsLoading(false)
       })
     }
     else {
-      getOIDCSSOUrl().then((res) => {
-        document.cookie = `oidc-state=${res.state}`
+      getUserOIDCSSOUrl().then((res) => {
+        document.cookie = `user-oidc-state=${res.state}`
         router.push(res.url)
       }).finally(() => {
         setIsLoading(false)
@@ -84,4 +84,4 @@ const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
   )
 }
 
-export default EnterpriseSSOForm
+export default UserSSOForm

+ 24 - 0
web/service/base.ts

@@ -10,6 +10,7 @@ import type {
   WorkflowFinishedResponse,
   WorkflowStartedResponse,
 } from '@/types/workflow'
+import { removeAccessToken } from '@/app/components/share/utils'
 const TIME_OUT = 100000
 
 const ContentType = {
@@ -97,6 +98,10 @@ function unicodeToChar(text: string) {
   })
 }
 
+function requiredWebSSOLogin() {
+  globalThis.location.href = `/webapp-signin?redirect_url=${globalThis.location.pathname}`
+}
+
 export function format(text: string) {
   let res = text.trim()
   if (res.startsWith('\n'))
@@ -308,6 +313,15 @@ const baseFetch = <T>(
                   return bodyJson.then((data: ResponseError) => {
                     if (!silent)
                       Toast.notify({ type: 'error', message: data.message })
+
+                    if (data.code === 'web_sso_auth_required')
+                      requiredWebSSOLogin()
+
+                    if (data.code === 'unauthorized') {
+                      removeAccessToken()
+                      globalThis.location.reload()
+                    }
+
                     return Promise.reject(data)
                   })
                 }
@@ -467,6 +481,16 @@ export const ssePost = (
       if (!/^(2|3)\d{2}$/.test(String(res.status))) {
         res.json().then((data: any) => {
           Toast.notify({ type: 'error', message: data.message || 'Server Error' })
+
+          if (isPublicAPI) {
+            if (data.code === 'web_sso_auth_required')
+              requiredWebSSOLogin()
+
+            if (data.code === 'unauthorized') {
+              removeAccessToken()
+              globalThis.location.reload()
+            }
+          }
         })
         onError?.('Server Error')
         return

+ 5 - 0
web/service/common.ts

@@ -34,6 +34,7 @@ import type {
   ModelProvider,
 } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import type { RETRIEVE_METHOD } from '@/types/app'
+import type { SystemFeatures } from '@/types/feature'
 
 export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => {
   return post(url, { body }) as Promise<CommonResponse & { data: string }>
@@ -271,3 +272,7 @@ type RetrievalMethodsRes = {
 export const fetchSupportRetrievalMethods: Fetcher<RetrievalMethodsRes, string> = (url) => {
   return get<RetrievalMethodsRes>(url)
 }
+
+export const getSystemFeatures = () => {
+  return get<SystemFeatures>('/system-features')
+}

+ 0 - 14
web/service/enterprise.ts

@@ -1,14 +0,0 @@
-import { get } from './base'
-import type { EnterpriseFeatures } from '@/types/enterprise'
-
-export const getEnterpriseFeatures = () => {
-  return get<EnterpriseFeatures>('/enterprise-features')
-}
-
-export const getSAMLSSOUrl = () => {
-  return get<{ url: string }>('/enterprise/sso/saml/login')
-}
-
-export const getOIDCSSOUrl = () => {
-  return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
-}

+ 24 - 0
web/service/share.ts

@@ -11,6 +11,7 @@ import type {
   ConversationItem,
 } from '@/models/share'
 import type { ChatConfig } from '@/app/components/base/chat/types'
+import type { SystemFeatures } from '@/types/feature'
 
 function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) {
   switch (action) {
@@ -135,6 +136,29 @@ export const fetchAppParams = async (isInstalledApp: boolean, installedAppId = '
   return (getAction('get', isInstalledApp))(getUrl('parameters', isInstalledApp, installedAppId)) as Promise<ChatConfig>
 }
 
+export const fetchSystemFeatures = async () => {
+  return (getAction('get', false))(getUrl('system-features', false, '')) as Promise<SystemFeatures>
+}
+
+export const fetchWebSAMLSSOUrl = async (appCode: string, redirectUrl: string) => {
+  return (getAction('get', false))(getUrl('/enterprise/sso/saml/login', false, ''), {
+    params: {
+      app_code: appCode,
+      redirect_url: redirectUrl,
+    },
+  }) as Promise<{ url: string }>
+}
+
+export const fetchWebOIDCSSOUrl = async (appCode: string, redirectUrl: string) => {
+  return (getAction('get', false))(getUrl('/enterprise/sso/oidc/login', false, ''), {
+    params: {
+      app_code: appCode,
+      redirect_url: redirectUrl,
+    },
+
+  }) as Promise<{ url: string }>
+}
+
 export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') => {
   return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise<AppMeta>
 }

+ 9 - 0
web/service/sso.ts

@@ -0,0 +1,9 @@
+import { get } from './base'
+
+export const getUserSAMLSSOUrl = () => {
+  return get<{ url: string }>('/enterprise/sso/saml/login')
+}
+
+export const getUserOIDCSSOUrl = () => {
+  return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
+}

+ 0 - 9
web/types/enterprise.ts

@@ -1,9 +0,0 @@
-export type EnterpriseFeatures = {
-  sso_enforced_for_signin: boolean
-  sso_enforced_for_signin_protocol: string
-}
-
-export const defaultEnterpriseFeatures: EnterpriseFeatures = {
-  sso_enforced_for_signin: false,
-  sso_enforced_for_signin_protocol: '',
-}

+ 13 - 0
web/types/feature.ts

@@ -0,0 +1,13 @@
+export type SystemFeatures = {
+  sso_enforced_for_signin: boolean
+  sso_enforced_for_signin_protocol: string
+  sso_enforced_for_web: boolean
+  sso_enforced_for_web_protocol: string
+}
+
+export const defaultSystemFeatures: SystemFeatures = {
+  sso_enforced_for_signin: false,
+  sso_enforced_for_signin_protocol: '',
+  sso_enforced_for_web: false,
+  sso_enforced_for_web_protocol: '',
+}