|
@@ -2,13 +2,16 @@
|
|
|
import base64
|
|
|
import logging
|
|
|
import secrets
|
|
|
+import uuid
|
|
|
from datetime import datetime
|
|
|
+from hashlib import sha256
|
|
|
from typing import Optional
|
|
|
|
|
|
from flask import session
|
|
|
from sqlalchemy import func
|
|
|
|
|
|
from events.tenant_event import tenant_was_created
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \
|
|
|
TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \
|
|
|
RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError
|
|
@@ -16,6 +19,7 @@ from libs.helper import get_remote_ip
|
|
|
from libs.password import compare_password, hash_password
|
|
|
from libs.rsa import generate_key_pair
|
|
|
from models.account import *
|
|
|
+from tasks.mail_invite_member_task import send_invite_member_mail_task
|
|
|
|
|
|
|
|
|
class AccountService:
|
|
@@ -48,12 +52,18 @@ class AccountService:
|
|
|
@staticmethod
|
|
|
def update_account_password(account, password, new_password):
|
|
|
"""update account password"""
|
|
|
- # todo: split validation and update
|
|
|
if account.password and not compare_password(password, account.password, account.password_salt):
|
|
|
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
|
|
- password_hashed = hash_password(new_password, account.password_salt)
|
|
|
+
|
|
|
+ # generate password salt
|
|
|
+ salt = secrets.token_bytes(16)
|
|
|
+ base64_salt = base64.b64encode(salt).decode()
|
|
|
+
|
|
|
+ # encrypt password with salt
|
|
|
+ password_hashed = hash_password(new_password, salt)
|
|
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
|
|
account.password = base64_password_hashed
|
|
|
+ account.password_salt = base64_salt
|
|
|
db.session.commit()
|
|
|
return account
|
|
|
|
|
@@ -283,8 +293,6 @@ class TenantService:
|
|
|
@staticmethod
|
|
|
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
|
|
"""Remove member from tenant"""
|
|
|
- # todo: check permission
|
|
|
-
|
|
|
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
|
|
raise CannotOperateSelfError("Cannot operate self.")
|
|
|
|
|
@@ -293,6 +301,12 @@ class TenantService:
|
|
|
raise MemberNotInTenantError("Member not in tenant.")
|
|
|
|
|
|
db.session.delete(ta)
|
|
|
+
|
|
|
+ account.initialized_at = None
|
|
|
+ account.status = AccountStatus.PENDING.value
|
|
|
+ account.password = None
|
|
|
+ account.password_salt = None
|
|
|
+
|
|
|
db.session.commit()
|
|
|
|
|
|
@staticmethod
|
|
@@ -332,8 +346,8 @@ class TenantService:
|
|
|
|
|
|
class RegisterService:
|
|
|
|
|
|
- @staticmethod
|
|
|
- def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
|
|
+ @classmethod
|
|
|
+ def register(cls, email, name, password: str = None, open_id: str = None, provider: str = None) -> Account:
|
|
|
db.session.begin_nested()
|
|
|
"""Register account"""
|
|
|
try:
|
|
@@ -359,9 +373,9 @@ class RegisterService:
|
|
|
|
|
|
return account
|
|
|
|
|
|
- @staticmethod
|
|
|
- def invite_new_member(tenant: Tenant, email: str, role: str = 'normal',
|
|
|
- inviter: Account = None) -> TenantAccountJoin:
|
|
|
+ @classmethod
|
|
|
+ def invite_new_member(cls, tenant: Tenant, email: str, role: str = 'normal',
|
|
|
+ inviter: Account = None) -> str:
|
|
|
"""Invite new member"""
|
|
|
account = Account.query.filter_by(email=email).first()
|
|
|
|
|
@@ -380,5 +394,71 @@ class RegisterService:
|
|
|
if ta:
|
|
|
raise AccountAlreadyInTenantError("Account already in tenant.")
|
|
|
|
|
|
- ta = TenantService.create_tenant_member(tenant, account, role)
|
|
|
- return ta
|
|
|
+ TenantService.create_tenant_member(tenant, account, role)
|
|
|
+
|
|
|
+ token = cls.generate_invite_token(tenant, account)
|
|
|
+
|
|
|
+ # send email
|
|
|
+ send_invite_member_mail_task.delay(
|
|
|
+ to=email,
|
|
|
+ token=cls.generate_invite_token(tenant, account),
|
|
|
+ inviter_name=inviter.name if inviter else 'Dify',
|
|
|
+ workspace_id=tenant.id,
|
|
|
+ workspace_name=tenant.name,
|
|
|
+ )
|
|
|
+
|
|
|
+ return token
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
|
|
+ token = str(uuid.uuid4())
|
|
|
+ email_hash = sha256(account.email.encode()).hexdigest()
|
|
|
+ cache_key = 'member_invite_token:{}, {}:{}'.format(str(tenant.id), email_hash, token)
|
|
|
+ redis_client.setex(cache_key, 3600, str(account.id))
|
|
|
+ return token
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def revoke_token(cls, workspace_id: str, email: str, token: str):
|
|
|
+ email_hash = sha256(email.encode()).hexdigest()
|
|
|
+ cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
|
|
+ redis_client.delete(cache_key)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_account_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[Account]:
|
|
|
+ tenant = db.session.query(Tenant).filter(
|
|
|
+ Tenant.id == workspace_id,
|
|
|
+ Tenant.status == 'normal'
|
|
|
+ ).first()
|
|
|
+
|
|
|
+ if not tenant:
|
|
|
+ return None
|
|
|
+
|
|
|
+ tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
|
|
+ TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
|
|
+ ).filter(Account.email == email, TenantAccountJoin.tenant_id == tenant.id).first()
|
|
|
+
|
|
|
+ if not tenant_account:
|
|
|
+ return None
|
|
|
+
|
|
|
+ account_id = cls._get_account_id_by_invite_token(workspace_id, email, token)
|
|
|
+ if not account_id:
|
|
|
+ return None
|
|
|
+
|
|
|
+ account = tenant_account[0]
|
|
|
+ if not account:
|
|
|
+ return None
|
|
|
+
|
|
|
+ if account_id != str(account.id):
|
|
|
+ return None
|
|
|
+
|
|
|
+ return account
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_account_id_by_invite_token(cls, workspace_id: str, email: str, token: str) -> Optional[str]:
|
|
|
+ email_hash = sha256(email.encode()).hexdigest()
|
|
|
+ cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
|
|
+ account_id = redis_client.get(cache_key)
|
|
|
+ if not account_id:
|
|
|
+ return None
|
|
|
+
|
|
|
+ return account_id.decode('utf-8')
|