ソースを参照

Feat/move tenant id into db (#2341)

crazywoola 1 年間 前
コミット
a8f23ed712

+ 1 - 7
api/controllers/console/auth/login.py

@@ -8,7 +8,7 @@ from flask import current_app, request
 from flask_restful import Resource, reqparse
 from libs.helper import email
 from libs.password import valid_password
-from services.account_service import AccountService, TenantService
+from services.account_service import AccountService
 
 
 class LoginApi(Resource):
@@ -30,11 +30,6 @@ class LoginApi(Resource):
         except services.errors.account.AccountLoginError:
             return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
 
-        try:
-            TenantService.switch_tenant(account)
-        except Exception:
-            pass
-
         AccountService.update_last_login(account, request)
 
         # todo: return the user info
@@ -47,7 +42,6 @@ class LogoutApi(Resource):
 
     @setup_required
     def get(self):
-        flask.session.pop('workspace_id', None)
         flask_login.logout_user()
         return {'result': 'success'}
 

+ 32 - 0
api/migrations/versions/16830a790f0f_.py

@@ -0,0 +1,32 @@
+"""empty message
+
+Revision ID: 16830a790f0f
+Revises: 380c6aa5a70d
+Create Date: 2024-02-01 08:21:31.111119
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '16830a790f0f'
+down_revision = '380c6aa5a70d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('current', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op:
+        batch_op.drop_column('current')
+
+    # ### end Alembic commands ###

+ 1 - 1
api/models/account.py

@@ -1,6 +1,5 @@
 import enum
 import json
-from math import e
 from typing import List
 
 from extensions.ext_database import db
@@ -155,6 +154,7 @@ class TenantAccountJoin(db.Model):
     id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(UUID, nullable=False)
     account_id = db.Column(UUID, nullable=False)
+    current = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     role = db.Column(db.String(16), nullable=False, server_default='normal')
     invited_by = db.Column(UUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 38 - 55
api/services/account_service.py

@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional
 from constants.languages import language_timezone_mapping, languages
 from events.tenant_event import tenant_was_created
 from extensions.ext_redis import redis_client
-from flask import current_app, session
+from flask import current_app
 from libs.helper import get_remote_ip
 from libs.passport import PassportService
 from libs.password import compare_password, hash_password
@@ -23,7 +23,8 @@ from services.errors.account import (AccountAlreadyInTenantError, AccountLoginEr
                                      NoPermissionError, RoleAlreadyAssignedError, TenantNotFound)
 from sqlalchemy import func
 from tasks.mail_invite_member_task import send_invite_member_mail_task
-from werkzeug.exceptions import Forbidden, Unauthorized
+from werkzeug.exceptions import Forbidden
+from sqlalchemy import exc
 
 
 def _create_tenant_for_account(account) -> Tenant:
@@ -39,54 +40,33 @@ class AccountService:
 
     @staticmethod
     def load_user(user_id: str) -> Account:
-        # todo: used by flask_login
-        if '.' in user_id:
-            tenant_id, account_id = user_id.split('.')
-        else:
-            account_id = user_id
-
-        account = db.session.query(Account).filter(Account.id == account_id).first()
-
-        if account:
-            if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
-                raise Forbidden('Account is banned or closed.')
-
-            workspace_id = session.get('workspace_id')
-            if workspace_id:
-                tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                    TenantAccountJoin.account_id == account.id,
-                    TenantAccountJoin.tenant_id == workspace_id
-                ).first()
-
-                if not tenant_account_join:
-                    tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                        TenantAccountJoin.account_id == account.id).first()
-
-                    if tenant_account_join:
-                        account.current_tenant_id = tenant_account_join.tenant_id
-                    else:
-                        _create_tenant_for_account(account)
-                    session['workspace_id'] = account.current_tenant_id
-                else:
-                    account.current_tenant_id = workspace_id
-            else:
-                tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                    TenantAccountJoin.account_id == account.id).first()
-                if tenant_account_join:
-                    account.current_tenant_id = tenant_account_join.tenant_id
-                else:
-                    _create_tenant_for_account(account)
-                session['workspace_id'] = account.current_tenant_id
-
-            current_time = datetime.utcnow()
+        account = Account.query.filter_by(id=user_id).first()
+        if not account:
+            return None
 
-            # update last_active_at when last_active_at is more than 10 minutes ago
-            if current_time - account.last_active_at > timedelta(minutes=10):
-                account.last_active_at = current_time
-                db.session.commit()
+        if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
+            raise Forbidden('Account is banned or closed.')
+        
+        # init owner's tenant
+        tenant_owner = TenantAccountJoin.query.filter_by(account_id=account.id, role='owner').first()
+        if not tenant_owner:
+            _create_tenant_for_account(account)
+        
+        current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
+        if current_tenant:
+            account.current_tenant_id = current_tenant.tenant_id
+        else:
+            account.current_tenant_id = tenant_owner.tenant_id
+            tenant_owner.current = True
+            db.session.commit()
+       
+        if datetime.utcnow() - account.last_active_at > timedelta(minutes=10):
+            account.last_active_at = datetime.utcnow()
+            db.session.commit()
 
         return account
 
+
     @staticmethod
     def get_account_jwt_token(account):
         payload = {
@@ -277,18 +257,21 @@ class TenantService:
     @staticmethod
     def switch_tenant(account: Account, tenant_id: int = None) -> None:
         """Switch the current workspace for the account"""
-        if not tenant_id:
-            tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first()
-        else:
-            tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
 
-        # Check if the tenant exists and the account is a member of the tenant
+        tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
         if not tenant_account_join:
             raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
-
-        # Set the current tenant for the account
-        account.current_tenant_id = tenant_account_join.tenant_id
-        session['workspace_id'] = account.current_tenant.id
+        else: 
+            with db.session.begin():
+                try:
+                    TenantAccountJoin.query.filter_by(account_id=account.id).update({'current': False})
+                    tenant_account_join.current = True
+                    db.session.commit()
+                    # Set the current tenant for the account
+                    account.current_tenant_id = tenant_account_join.tenant_id
+                except exc.SQLAlchemyError:
+                    db.session.rollback()
+                    raise
 
     @staticmethod
     def get_tenant_members(tenant: Tenant) -> List[Account]: