瀏覽代碼

feat(api/auth): switch-to-stateful-authentication (#5438)

-LAN- 10 月之前
父節點
當前提交
1336b844fd

+ 20 - 21
api/app.py

@@ -2,7 +2,7 @@ import os
 
 from configs.app_configs import DifyConfigs
 
-if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
+if not os.environ.get("DEBUG") or os.environ.get("DEBUG", "false").lower() != 'true':
     from gevent import monkey
 
     monkey.patch_all()
@@ -152,27 +152,26 @@ def initialize_extensions(app):
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint in ['console', 'inner_api']:
-        # Check if the user_id contains a dot, indicating the old format
-        auth_header = request.headers.get('Authorization', '')
-        if not auth_header:
-            auth_token = request.args.get('_token')
-            if not auth_token:
-                raise Unauthorized('Invalid Authorization token.')
-        else:
-            if ' ' not in auth_header:
-                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-            auth_scheme, auth_token = auth_header.split(None, 1)
-            auth_scheme = auth_scheme.lower()
-            if auth_scheme != 'bearer':
-                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-
-        decoded = PassportService().verify(auth_token)
-        user_id = decoded.get('user_id')
-
-        return AccountService.load_user(user_id)
-    else:
+    if request.blueprint not in ['console', 'inner_api']:
         return None
+    # Check if the user_id contains a dot, indicating the old format
+    auth_header = request.headers.get('Authorization', '')
+    if not auth_header:
+        auth_token = request.args.get('_token')
+        if not auth_token:
+            raise Unauthorized('Invalid Authorization token.')
+    else:
+        if ' ' not in auth_header:
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme = auth_scheme.lower()
+        if auth_scheme != 'bearer':
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+
+    decoded = PassportService().verify(auth_token)
+    user_id = decoded.get('user_id')
+
+    return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
 
 
 @login_manager.unauthorized_handler

+ 8 - 5
api/controllers/console/auth/login.py

@@ -1,3 +1,5 @@
+from typing import cast
+
 import flask_login
 from flask import current_app, request
 from flask_restful import Resource, reqparse
@@ -5,8 +7,9 @@ from flask_restful import Resource, reqparse
 import services
 from controllers.console import api
 from controllers.console.setup import setup_required
-from libs.helper import email
+from libs.helper import email, get_remote_ip
 from libs.password import valid_password
+from models.account import Account
 from services.account_service import AccountService, TenantService
 
 
@@ -34,10 +37,7 @@ class LoginApi(Resource):
         if len(tenants) == 0:
             return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
 
-        AccountService.update_last_login(account, request)
-
-        # todo: return the user info
-        token = AccountService.get_account_jwt_token(account)
+        token = AccountService.login(account, ip_address=get_remote_ip(request))
 
         return {'result': 'success', 'data': token}
 
@@ -46,6 +46,9 @@ class LogoutApi(Resource):
 
     @setup_required
     def get(self):
+        account = cast(Account, flask_login.current_user)
+        token = request.headers.get('Authorization', '').split(' ')[1]
+        AccountService.logout(account=account, token=token)
         flask_login.logout_user()
         return {'result': 'success'}
 

+ 2 - 3
api/controllers/console/auth/oauth.py

@@ -8,6 +8,7 @@ from flask_restful import Resource
 
 from constants.languages import languages
 from extensions.ext_database import db
+from libs.helper import get_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models.account import Account, AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
@@ -78,9 +79,7 @@ class OAuthCallback(Resource):
 
         TenantService.create_owner_tenant_if_not_exist(account)
 
-        AccountService.update_last_login(account, request)
-
-        token = AccountService.get_account_jwt_token(account)
+        token = AccountService.login(account, ip_address=get_remote_ip(request))
 
         return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
 

+ 2 - 2
api/controllers/console/setup.py

@@ -4,7 +4,7 @@ from flask import current_app, request
 from flask_restful import Resource, reqparse
 
 from extensions.ext_database import db
-from libs.helper import email, str_len
+from libs.helper import email, get_remote_ip, str_len
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import AccountService, RegisterService, TenantService
@@ -61,7 +61,7 @@ class SetupApi(Resource):
         TenantService.create_owner_tenant_if_not_exist(account)
 
         setup()
-        AccountService.update_last_login(account, request)
+        AccountService.update_last_login(account, ip_address=get_remote_ip(request))
 
         return {'result': 'success'}, 201
 

+ 1 - 1
api/libs/helper.py

@@ -140,7 +140,7 @@ def generate_string(n):
     return result
 
 
-def get_remote_ip(request):
+def get_remote_ip(request) -> str:
     if request.headers.get('CF-Connecting-IP'):
         return request.headers.get('Cf-Connecting-Ip')
     elif request.headers.getlist("X-Forwarded-For"):

+ 3 - 1
api/services/__init__.py

@@ -1 +1,3 @@
-import services.errors
+from . import errors
+
+__all__ = ['errors']

+ 25 - 5
api/services/account_service.py

@@ -13,7 +13,6 @@ from werkzeug.exceptions import Unauthorized
 from constants.languages import language_timezone_mapping, languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_redis import redis_client
-from libs.helper import get_remote_ip
 from libs.passport import PassportService
 from libs.password import compare_password, hash_password, valid_password
 from libs.rsa import generate_key_pair
@@ -67,10 +66,10 @@ class AccountService:
 
 
     @staticmethod
-    def get_account_jwt_token(account):
+    def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
         payload = {
             "user_id": account.id,
-            "exp": datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30),
+            "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
             "iss": current_app.config['EDITION'],
             "sub": 'Console API Passport',
         }
@@ -195,14 +194,35 @@ class AccountService:
         return account
 
     @staticmethod
-    def update_last_login(account: Account, request) -> None:
+    def update_last_login(account: Account, *, ip_address: str) -> None:
         """Update last login time and ip"""
         account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
-        account.last_login_ip = get_remote_ip(request)
+        account.last_login_ip = ip_address
         db.session.add(account)
         db.session.commit()
         logging.info(f'Account {account.id} logged in successfully.')
 
+    @staticmethod
+    def login(account: Account, *, ip_address: Optional[str] = None):
+        if ip_address:
+            AccountService.update_last_login(account, ip_address=ip_address)
+        exp = timedelta(days=30)
+        token = AccountService.get_account_jwt_token(account, exp=exp)
+        redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
+        return token
+
+    @staticmethod
+    def logout(*, account: Account, token: str):
+        redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
+
+    @staticmethod
+    def load_logged_in_account(*, account_id: str, token: str):
+        if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
+            return None
+        return AccountService.load_user(account_id)
+
+def _get_login_cache_key(*, account_id: str, token: str):
+    return f"account_login:{account_id}:{token}"
 
 class TenantService:
 

+ 27 - 4
api/services/errors/__init__.py

@@ -1,6 +1,29 @@
+from . import (
+    account,
+    app,
+    app_model_config,
+    audio,
+    base,
+    completion,
+    conversation,
+    dataset,
+    document,
+    file,
+    index,
+    message,
+)
+
 __all__ = [
-    'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
-    'app', 'completion', 'audio', 'file'
+    "base",
+    "conversation",
+    "message",
+    "index",
+    "app_model_config",
+    "account",
+    "document",
+    "dataset",
+    "app",
+    "completion",
+    "audio",
+    "file",
 ]
-
-from . import *