|
@@ -1,8 +1,7 @@
|
|
# -*- coding:utf-8 -*-
|
|
# -*- coding:utf-8 -*-
|
|
import os
|
|
import os
|
|
-from datetime import datetime, timedelta
|
|
|
|
|
|
|
|
-from werkzeug.exceptions import Forbidden
|
|
|
|
|
|
+from werkzeug.exceptions import Unauthorized
|
|
|
|
|
|
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
|
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
|
from gevent import monkey
|
|
from gevent import monkey
|
|
@@ -12,12 +11,11 @@ import logging
|
|
import json
|
|
import json
|
|
import threading
|
|
import threading
|
|
|
|
|
|
-from flask import Flask, request, Response, session
|
|
|
|
-import flask_login
|
|
|
|
|
|
+from flask import Flask, request, Response
|
|
from flask_cors import CORS
|
|
from flask_cors import CORS
|
|
|
|
|
|
from core.model_providers.providers import hosted
|
|
from core.model_providers.providers import hosted
|
|
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
|
|
|
|
|
+from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
|
ext_database, ext_storage, ext_mail, ext_stripe
|
|
ext_database, ext_storage, ext_mail, ext_stripe
|
|
from extensions.ext_database import db
|
|
from extensions.ext_database import db
|
|
from extensions.ext_login import login_manager
|
|
from extensions.ext_login import login_manager
|
|
@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
|
|
from events import event_handlers
|
|
from events import event_handlers
|
|
# DO NOT REMOVE ABOVE
|
|
# DO NOT REMOVE ABOVE
|
|
|
|
|
|
-import core
|
|
|
|
from config import Config, CloudEditionConfig
|
|
from config import Config, CloudEditionConfig
|
|
from commands import register_commands
|
|
from commands import register_commands
|
|
-from models.account import TenantAccountJoin, AccountStatus
|
|
|
|
-from models.model import Account, EndUser, App
|
|
|
|
-from services.account_service import TenantService
|
|
|
|
|
|
+from services.account_service import AccountService
|
|
|
|
+from libs.passport import PassportService
|
|
|
|
|
|
import warnings
|
|
import warnings
|
|
warnings.simplefilter("ignore", ResourceWarning)
|
|
warnings.simplefilter("ignore", ResourceWarning)
|
|
@@ -85,81 +81,33 @@ def initialize_extensions(app):
|
|
ext_redis.init_app(app)
|
|
ext_redis.init_app(app)
|
|
ext_storage.init_app(app)
|
|
ext_storage.init_app(app)
|
|
ext_celery.init_app(app)
|
|
ext_celery.init_app(app)
|
|
- ext_session.init_app(app)
|
|
|
|
ext_login.init_app(app)
|
|
ext_login.init_app(app)
|
|
ext_mail.init_app(app)
|
|
ext_mail.init_app(app)
|
|
ext_sentry.init_app(app)
|
|
ext_sentry.init_app(app)
|
|
ext_stripe.init_app(app)
|
|
ext_stripe.init_app(app)
|
|
|
|
|
|
|
|
|
|
-def _create_tenant_for_account(account):
|
|
|
|
- tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
|
|
|
-
|
|
|
|
- TenantService.create_tenant_member(tenant, account, role='owner')
|
|
|
|
- account.current_tenant = tenant
|
|
|
|
-
|
|
|
|
- return tenant
|
|
|
|
-
|
|
|
|
-
|
|
|
|
# Flask-Login configuration
|
|
# Flask-Login configuration
|
|
-@login_manager.user_loader
|
|
|
|
-def load_user(user_id):
|
|
|
|
- """Load user based on the user_id."""
|
|
|
|
|
|
+@login_manager.request_loader
|
|
|
|
+def load_user_from_request(request_from_flask_login):
|
|
|
|
+ """Load user based on the request."""
|
|
if request.blueprint == 'console':
|
|
if request.blueprint == 'console':
|
|
# Check if the user_id contains a dot, indicating the old format
|
|
# Check if the user_id contains a dot, indicating the old format
|
|
- if '.' in user_id:
|
|
|
|
- tenant_id, account_id = user_id.split('.')
|
|
|
|
- else:
|
|
|
|
- account_id = user_id
|
|
|
|
-
|
|
|
|
- account = db.session.query(Account).filter(Account.id == account_id).first()
|
|
|
|
-
|
|
|
|
- if account:
|
|
|
|
- if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
|
|
|
- raise Forbidden('Account is banned or closed.')
|
|
|
|
-
|
|
|
|
- workspace_id = session.get('workspace_id')
|
|
|
|
- if workspace_id:
|
|
|
|
- tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
|
|
- TenantAccountJoin.account_id == account.id,
|
|
|
|
- TenantAccountJoin.tenant_id == workspace_id
|
|
|
|
- ).first()
|
|
|
|
-
|
|
|
|
- if not tenant_account_join:
|
|
|
|
- tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
|
|
- TenantAccountJoin.account_id == account.id).first()
|
|
|
|
-
|
|
|
|
- if tenant_account_join:
|
|
|
|
- account.current_tenant_id = tenant_account_join.tenant_id
|
|
|
|
- else:
|
|
|
|
- _create_tenant_for_account(account)
|
|
|
|
- session['workspace_id'] = account.current_tenant_id
|
|
|
|
- else:
|
|
|
|
- account.current_tenant_id = workspace_id
|
|
|
|
- else:
|
|
|
|
- tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
|
|
|
- TenantAccountJoin.account_id == account.id).first()
|
|
|
|
- if tenant_account_join:
|
|
|
|
- account.current_tenant_id = tenant_account_join.tenant_id
|
|
|
|
- else:
|
|
|
|
- _create_tenant_for_account(account)
|
|
|
|
- session['workspace_id'] = account.current_tenant_id
|
|
|
|
-
|
|
|
|
- current_time = datetime.utcnow()
|
|
|
|
-
|
|
|
|
- # update last_active_at when last_active_at is more than 10 minutes ago
|
|
|
|
- if current_time - account.last_active_at > timedelta(minutes=10):
|
|
|
|
- account.last_active_at = current_time
|
|
|
|
- db.session.commit()
|
|
|
|
-
|
|
|
|
- # Log in the user with the updated user_id
|
|
|
|
- flask_login.login_user(account, remember=True)
|
|
|
|
-
|
|
|
|
- return account
|
|
|
|
|
|
+ auth_header = request.headers.get('Authorization', '')
|
|
|
|
+ if ' ' not in auth_header:
|
|
|
|
+ raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
|
|
|
+ auth_scheme, auth_token = auth_header.split(None, 1)
|
|
|
|
+ auth_scheme = auth_scheme.lower()
|
|
|
|
+ if auth_scheme != 'bearer':
|
|
|
|
+ raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
|
|
|
+
|
|
|
|
+ decoded = PassportService().verify(auth_token)
|
|
|
|
+ user_id = decoded.get('user_id')
|
|
|
|
+
|
|
|
|
+ return AccountService.load_user(user_id)
|
|
else:
|
|
else:
|
|
return None
|
|
return None
|
|
|
|
|
|
-
|
|
|
|
@login_manager.unauthorized_handler
|
|
@login_manager.unauthorized_handler
|
|
def unauthorized_handler():
|
|
def unauthorized_handler():
|
|
"""Handle unauthorized requests."""
|
|
"""Handle unauthorized requests."""
|
|
@@ -216,6 +164,7 @@ if app.config['TESTING']:
|
|
@app.after_request
|
|
@app.after_request
|
|
def after_request(response):
|
|
def after_request(response):
|
|
"""Add Version headers to the response."""
|
|
"""Add Version headers to the response."""
|
|
|
|
+ response.set_cookie('remember_token', '', expires=0)
|
|
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
|
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
|
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
|
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
|
return response
|
|
return response
|