# oauth.py (4.5 KB)
  1. import json
  2. import urllib.parse
  3. from dataclasses import dataclass
  4. import requests
  5. from extensions.ext_database import db
  6. from flask_login import current_user
  7. from models.source import DataSourceBinding
  8. @dataclass
  9. class OAuthUserInfo:
  10. id: str
  11. name: str
  12. email: str
  13. class OAuth:
  14. def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
  15. self.client_id = client_id
  16. self.client_secret = client_secret
  17. self.redirect_uri = redirect_uri
  18. def get_authorization_url(self):
  19. raise NotImplementedError()
  20. def get_access_token(self, code: str):
  21. raise NotImplementedError()
  22. def get_raw_user_info(self, token: str):
  23. raise NotImplementedError()
  24. def get_user_info(self, token: str) -> OAuthUserInfo:
  25. raw_info = self.get_raw_user_info(token)
  26. return self._transform_user_info(raw_info)
  27. def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
  28. raise NotImplementedError()
  29. class GitHubOAuth(OAuth):
  30. _AUTH_URL = 'https://github.com/login/oauth/authorize'
  31. _TOKEN_URL = 'https://github.com/login/oauth/access_token'
  32. _USER_INFO_URL = 'https://api.github.com/user'
  33. _EMAIL_INFO_URL = 'https://api.github.com/user/emails'
  34. def get_authorization_url(self):
  35. params = {
  36. 'client_id': self.client_id,
  37. 'redirect_uri': self.redirect_uri,
  38. 'scope': 'user:email' # Request only basic user information
  39. }
  40. return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
  41. def get_access_token(self, code: str):
  42. data = {
  43. 'client_id': self.client_id,
  44. 'client_secret': self.client_secret,
  45. 'code': code,
  46. 'redirect_uri': self.redirect_uri
  47. }
  48. headers = {'Accept': 'application/json'}
  49. response = requests.post(self._TOKEN_URL, data=data, headers=headers)
  50. response_json = response.json()
  51. access_token = response_json.get('access_token')
  52. if not access_token:
  53. raise ValueError(f"Error in GitHub OAuth: {response_json}")
  54. return access_token
  55. def get_raw_user_info(self, token: str):
  56. headers = {'Authorization': f"token {token}"}
  57. response = requests.get(self._USER_INFO_URL, headers=headers)
  58. response.raise_for_status()
  59. user_info = response.json()
  60. email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
  61. email_info = email_response.json()
  62. primary_email = next((email for email in email_info if email['primary'] == True), None)
  63. return {**user_info, 'email': primary_email['email']}
  64. def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
  65. email = raw_info.get('email')
  66. if not email:
  67. email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
  68. return OAuthUserInfo(
  69. id=str(raw_info['id']),
  70. name=raw_info['name'],
  71. email=email
  72. )
  73. class GoogleOAuth(OAuth):
  74. _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
  75. _TOKEN_URL = 'https://oauth2.googleapis.com/token'
  76. _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
  77. def get_authorization_url(self):
  78. params = {
  79. 'client_id': self.client_id,
  80. 'response_type': 'code',
  81. 'redirect_uri': self.redirect_uri,
  82. 'scope': 'openid email'
  83. }
  84. return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
  85. def get_access_token(self, code: str):
  86. data = {
  87. 'client_id': self.client_id,
  88. 'client_secret': self.client_secret,
  89. 'code': code,
  90. 'grant_type': 'authorization_code',
  91. 'redirect_uri': self.redirect_uri
  92. }
  93. headers = {'Accept': 'application/json'}
  94. response = requests.post(self._TOKEN_URL, data=data, headers=headers)
  95. response_json = response.json()
  96. access_token = response_json.get('access_token')
  97. if not access_token:
  98. raise ValueError(f"Error in Google OAuth: {response_json}")
  99. return access_token
  100. def get_raw_user_info(self, token: str):
  101. headers = {'Authorization': f"Bearer {token}"}
  102. response = requests.get(self._USER_INFO_URL, headers=headers)
  103. response.raise_for_status()
  104. return response.json()
  105. def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
  106. return OAuthUserInfo(
  107. id=str(raw_info['sub']),
  108. name=None,
  109. email=raw_info['email']
  110. )