wraps.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # -*- coding:utf-8 -*-
  2. from datetime import datetime
  3. from functools import wraps
  4. from flask import request
  5. from flask_restful import Resource
  6. from werkzeug.exceptions import NotFound, Unauthorized
  7. from extensions.ext_database import db
  8. from models.dataset import Dataset
  9. from models.model import ApiToken, App
  10. def validate_app_token(view=None):
  11. def decorator(view):
  12. @wraps(view)
  13. def decorated(*args, **kwargs):
  14. api_token = validate_and_get_api_token('app')
  15. app_model = db.session.query(App).get(api_token.app_id)
  16. if not app_model:
  17. raise NotFound()
  18. if app_model.status != 'normal':
  19. raise NotFound()
  20. if not app_model.enable_api:
  21. raise NotFound()
  22. return view(app_model, None, *args, **kwargs)
  23. return decorated
  24. if view:
  25. return decorator(view)
  26. # if view is None, it means that the decorator is used without parentheses
  27. # use the decorator as a function for method_decorators
  28. return decorator
  29. def validate_dataset_token(view=None):
  30. def decorator(view):
  31. @wraps(view)
  32. def decorated(*args, **kwargs):
  33. api_token = validate_and_get_api_token('dataset')
  34. dataset = db.session.query(Dataset).get(api_token.dataset_id)
  35. if not dataset:
  36. raise NotFound()
  37. return view(dataset, *args, **kwargs)
  38. return decorated
  39. if view:
  40. return decorator(view)
  41. # if view is None, it means that the decorator is used without parentheses
  42. # use the decorator as a function for method_decorators
  43. return decorator
  44. def validate_and_get_api_token(scope=None):
  45. """
  46. Validate and get API token.
  47. """
  48. auth_header = request.headers.get('Authorization')
  49. if auth_header is None:
  50. raise Unauthorized()
  51. auth_scheme, auth_token = auth_header.split(None, 1)
  52. auth_scheme = auth_scheme.lower()
  53. if auth_scheme != 'bearer':
  54. raise Unauthorized()
  55. api_token = db.session.query(ApiToken).filter(
  56. ApiToken.token == auth_token,
  57. ApiToken.type == scope,
  58. ).first()
  59. if not api_token:
  60. raise Unauthorized()
  61. api_token.last_used_at = datetime.utcnow()
  62. db.session.commit()
  63. return api_token
  64. class AppApiResource(Resource):
  65. method_decorators = [validate_app_token]
  66. class DatasetApiResource(Resource):
  67. method_decorators = [validate_dataset_token]