|
@@ -1,5 +1,5 @@
|
|
|
from collections.abc import Callable
|
|
|
-from datetime import UTC, datetime
|
|
|
+from datetime import UTC, datetime, timedelta
|
|
|
from enum import Enum
|
|
|
from functools import wraps
|
|
|
from typing import Optional
|
|
@@ -8,6 +8,8 @@ from flask import current_app, request
|
|
|
from flask_login import user_logged_in # type: ignore
|
|
|
from flask_restful import Resource # type: ignore
|
|
|
from pydantic import BaseModel
|
|
|
+from sqlalchemy import select, update
|
|
|
+from sqlalchemy.orm import Session
|
|
|
from werkzeug.exceptions import Forbidden, Unauthorized
|
|
|
|
|
|
from extensions.ext_database import db
|
|
@@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
-def validate_and_get_api_token(scope=None):
|
|
|
+def validate_and_get_api_token(scope: str | None = None):
|
|
|
"""
|
|
|
Validate and get API token.
|
|
|
"""
|
|
@@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
|
|
|
if auth_scheme != "bearer":
|
|
|
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
|
|
|
|
|
- api_token = (
|
|
|
- db.session.query(ApiToken)
|
|
|
- .filter(
|
|
|
- ApiToken.token == auth_token,
|
|
|
- ApiToken.type == scope,
|
|
|
+ current_time = datetime.now(UTC).replace(tzinfo=None)
|
|
|
+ cutoff_time = current_time - timedelta(minutes=1)
|
|
|
+ with Session(db.engine, expire_on_commit=False) as session:
|
|
|
+ update_stmt = (
|
|
|
+ update(ApiToken)
|
|
|
+ .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
|
|
|
+ .values(last_used_at=current_time)
|
|
|
+ .returning(ApiToken)
|
|
|
)
|
|
|
- .first()
|
|
|
- )
|
|
|
-
|
|
|
- if not api_token:
|
|
|
- raise Unauthorized("Access token is invalid")
|
|
|
-
|
|
|
- api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
- db.session.commit()
|
|
|
+ result = session.execute(update_stmt)
|
|
|
+ api_token = result.scalar_one_or_none()
|
|
|
+
|
|
|
+ if not api_token:
|
|
|
+ stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
|
|
+ api_token = session.scalar(stmt)
|
|
|
+ if not api_token:
|
|
|
+ raise Unauthorized("Access token is invalid")
|
|
|
+ else:
|
|
|
+ session.commit()
|
|
|
|
|
|
return api_token
|
|
|
|