Browse Source

Feat/implement-refresh-tokens (#9233)

-LAN- 6 tháng trước cách đây
mục cha
commit
f73751843f

+ 3 - 0
api/.env.example

@@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
 # The time in seconds after the signature is rejected
 FILES_ACCESS_TIMEOUT=300
 
+# Access token expiration time in minutes
+ACCESS_TOKEN_EXPIRE_MINUTES=60
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 

+ 1 - 1
api/app.py

@@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login):
     decoded = PassportService().verify(auth_token)
     user_id = decoded.get("user_id")
 
-    logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
     if logged_in_account:
         contexts.tenant_id.set(logged_in_account.current_tenant_id)
     return logged_in_account

+ 10 - 5
api/configs/feature/__init__.py

@@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings):
     )
 
 
-class OAuthConfig(BaseSettings):
+class AuthConfig(BaseSettings):
     """
-    Configuration for OAuth authentication
+    Configuration for authentication and OAuth
     """
 
     OAUTH_REDIRECT_PATH: str = Field(
@@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings):
     )
 
     GITHUB_CLIENT_ID: Optional[str] = Field(
-        description="GitHub OAuth client secret",
+        description="GitHub OAuth client ID",
         default=None,
     )
 
@@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings):
         default=None,
     )
 
+    ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
+        description="Expiration time for access tokens in minutes",
+        default=60,
+    )
+
 
 class ModerationConfig(BaseSettings):
     """
@@ -607,6 +612,7 @@ class PositionConfig(BaseSettings):
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
+    AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
     DataSetConfig,
@@ -621,14 +627,13 @@ class FeatureConfig(
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
-    OAuthConfig,
+    PositionConfig,
     RagEtlConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
     WorkflowConfig,
     WorkspaceConfig,
-    PositionConfig,
     # hosted services config
     HostedServiceConfig,
     CeleryBeatConfig,

+ 18 - 5
api/controllers/console/auth/login.py

@@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse
 import services
 from controllers.console import api
 from controllers.console.setup import setup_required
-from libs.helper import email, get_remote_ip
+from libs.helper import email, extract_remote_ip
 from libs.password import valid_password
 from models.account import Account
 from services.account_service import AccountService, TenantService
@@ -40,17 +40,16 @@ class LoginApi(Resource):
                 "data": "workspace not found, please contact system admin to invite you to join in a workspace",
             }
 
-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
 
-        return {"result": "success", "data": token}
+        return {"result": "success", "data": token_pair.model_dump()}
 
 
 class LogoutApi(Resource):
     @setup_required
     def get(self):
         account = cast(Account, flask_login.current_user)
-        token = request.headers.get("Authorization", "").split(" ")[1]
-        AccountService.logout(account=account, token=token)
+        AccountService.logout(account=account)
         flask_login.logout_user()
         return {"result": "success"}
 
@@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
         return {"result": "success"}
 
 
+class RefreshTokenApi(Resource):
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("refresh_token", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        try:
+            new_token_pair = AccountService.refresh_token(args["refresh_token"])
+            return {"result": "success", "data": new_token_pair.model_dump()}
+        except Exception as e:
+            return {"result": "fail", "data": str(e)}, 401
+
+
 api.add_resource(LoginApi, "/login")
 api.add_resource(LogoutApi, "/logout")
+api.add_resource(RefreshTokenApi, "/refresh-token")

+ 8 - 3
api/controllers/console/auth/oauth.py

@@ -9,7 +9,7 @@ from flask_restful import Resource
 from configs import dify_config
 from constants.languages import languages
 from extensions.ext_database import db
-from libs.helper import get_remote_ip
+from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
 from models.account import Account, AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
@@ -81,9 +81,14 @@ class OAuthCallback(Resource):
 
         TenantService.create_owner_tenant_if_not_exist(account)
 
-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(
+            account=account,
+            ip_address=extract_remote_ip(request),
+        )
 
-        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
+        return redirect(
+            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
+        )
 
 
 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

+ 2 - 2
api/controllers/console/setup.py

@@ -4,7 +4,7 @@ from flask import request
 from flask_restful import Resource, reqparse
 
 from configs import dify_config
-from libs.helper import StrLen, email, get_remote_ip
+from libs.helper import StrLen, email, extract_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import RegisterService, TenantService
@@ -46,7 +46,7 @@ class SetupApi(Resource):
 
         # setup
         RegisterService.setup(
-            email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
+            email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
         )
 
         return {"result": "success"}, 201

+ 1 - 1
api/libs/helper.py

@@ -162,7 +162,7 @@ def generate_string(n):
     return result
 
 
-def get_remote_ip(request) -> str:
+def extract_remote_ip(request) -> str:
     if request.headers.get("CF-Connecting-IP"):
         return request.headers.get("Cf-Connecting-Ip")
     elif request.headers.getlist("X-Forwarded-For"):

+ 76 - 21
api/services/account_service.py

@@ -7,6 +7,7 @@ from datetime import datetime, timedelta, timezone
 from hashlib import sha256
 from typing import Any, Optional
 
+from pydantic import BaseModel
 from sqlalchemy import func
 from werkzeug.exceptions import Unauthorized
 
@@ -49,9 +50,39 @@ from tasks.mail_invite_member_task import send_invite_member_mail_task
 from tasks.mail_reset_password_task import send_reset_password_mail_task
 
 
+class TokenPair(BaseModel):
+    access_token: str
+    refresh_token: str
+
+
+REFRESH_TOKEN_PREFIX = "refresh_token:"
+ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
+REFRESH_TOKEN_EXPIRY = timedelta(days=30)
+
+
 class AccountService:
     reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
 
+    @staticmethod
+    def _get_refresh_token_key(refresh_token: str) -> str:
+        return f"{REFRESH_TOKEN_PREFIX}{refresh_token}"
+
+    @staticmethod
+    def _get_account_refresh_token_key(account_id: str) -> str:
+        return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
+
+    @staticmethod
+    def _store_refresh_token(refresh_token: str, account_id: str) -> None:
+        redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
+        redis_client.setex(
+            AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
+        )
+
+    @staticmethod
+    def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
+        redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
+        redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
+
     @staticmethod
     def load_user(user_id: str) -> None | Account:
         account = Account.query.filter_by(id=user_id).first()
@@ -61,9 +92,7 @@ class AccountService:
         if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
             raise Unauthorized("Account is banned or closed.")
 
-        current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
-            account_id=account.id, current=True
-        ).first()
+        current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
         if current_tenant:
             account.current_tenant_id = current_tenant.tenant_id
         else:
@@ -84,10 +113,12 @@ class AccountService:
         return account
 
     @staticmethod
-    def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
+    def get_account_jwt_token(account: Account) -> str:
+        exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
+        exp = int(exp_dt.timestamp())
         payload = {
             "user_id": account.id,
-            "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
+            "exp": exp,
             "iss": dify_config.EDITION,
             "sub": "Console API Passport",
         }
@@ -213,7 +244,7 @@ class AccountService:
         return account
 
     @staticmethod
-    def update_last_login(account: Account, *, ip_address: str) -> None:
+    def update_login_info(account: Account, *, ip_address: str) -> None:
         """Update last login time and ip"""
         account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None)
         account.last_login_ip = ip_address
@@ -221,22 +252,45 @@ class AccountService:
         db.session.commit()
 
     @staticmethod
-    def login(account: Account, *, ip_address: Optional[str] = None):
+    def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
         if ip_address:
-            AccountService.update_last_login(account, ip_address=ip_address)
-        exp = timedelta(days=30)
-        token = AccountService.get_account_jwt_token(account, exp=exp)
-        redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
-        return token
+            AccountService.update_login_info(account=account, ip_address=ip_address)
+
+        access_token = AccountService.get_account_jwt_token(account=account)
+        refresh_token = _generate_refresh_token()
+
+        AccountService._store_refresh_token(refresh_token, account.id)
+
+        return TokenPair(access_token=access_token, refresh_token=refresh_token)
 
     @staticmethod
-    def logout(*, account: Account, token: str):
-        redis_client.delete(_get_login_cache_key(account_id=account.id, token=token))
+    def logout(*, account: Account) -> None:
+        refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
+        if refresh_token:
+            AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
 
     @staticmethod
-    def load_logged_in_account(*, account_id: str, token: str):
-        if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)):
-            return None
+    def refresh_token(refresh_token: str) -> TokenPair:
+        # Verify the refresh token
+        account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
+        if not account_id:
+            raise ValueError("Invalid refresh token")
+
+        account = AccountService.load_user(account_id.decode("utf-8"))
+        if not account:
+            raise ValueError("Invalid account")
+
+        # Generate new access token and refresh token
+        new_access_token = AccountService.get_account_jwt_token(account)
+        new_refresh_token = _generate_refresh_token()
+
+        AccountService._delete_refresh_token(refresh_token, account.id)
+        AccountService._store_refresh_token(new_refresh_token, account.id)
+
+        return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
+
+    @staticmethod
+    def load_logged_in_account(*, account_id: str):
         return AccountService.load_user(account_id)
 
     @classmethod
@@ -258,10 +312,6 @@ class AccountService:
         return TokenManager.get_token_data(token, "reset_password")
 
 
-def _get_login_cache_key(*, account_id: str, token: str):
-    return f"account_login:{account_id}:{token}"
-
-
 class TenantService:
     @staticmethod
     def create_tenant(name: str) -> Tenant:
@@ -698,3 +748,8 @@ class RegisterService:
 
             invitation = json.loads(data)
             return invitation
+
+
+def _generate_refresh_token(length: int = 64):
+    token = secrets.token_hex(length)
+    return token

+ 3 - 0
docker/.env.example

@@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
 # The default value is 300 seconds.
 FILES_ACCESS_TIMEOUT=300
 
+# Access token expiration time in minutes
+ACCESS_TOKEN_EXPIRE_MINUTES=60
+
 # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
 APP_MAX_ACTIVE_REQUESTS=0
 

+ 1 - 0
docker/docker-compose.yaml

@@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env
   REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-}
   REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-}
   REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-}
+  ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
   REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1}
   CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
   BROKER_USE_SSL: ${BROKER_USE_SSL:-false}