wraps.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from collections.abc import Callable
  2. from functools import wraps
  3. from typing import Optional
  4. from flask import request
  5. from flask_restful import reqparse # type: ignore
  6. from pydantic import BaseModel
  7. from sqlalchemy.orm import Session
  8. from extensions.ext_database import db
  9. from models.account import Account, Tenant
  10. from models.model import EndUser
  11. from services.account_service import AccountService
  12. def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
  13. try:
  14. with Session(db.engine) as session:
  15. if not user_id:
  16. user_id = "DEFAULT-USER"
  17. if user_id == "DEFAULT-USER":
  18. user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
  19. if not user_model:
  20. user_model = EndUser(
  21. tenant_id=tenant_id,
  22. type="service_api",
  23. is_anonymous=True if user_id == "DEFAULT-USER" else False,
  24. session_id=user_id,
  25. )
  26. session.add(user_model)
  27. session.commit()
  28. else:
  29. user_model = AccountService.load_user(user_id)
  30. if not user_model:
  31. user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
  32. if not user_model:
  33. raise ValueError("user not found")
  34. except Exception:
  35. raise ValueError("user not found")
  36. return user_model
  37. def get_user_tenant(view: Optional[Callable] = None):
  38. def decorator(view_func):
  39. @wraps(view_func)
  40. def decorated_view(*args, **kwargs):
  41. # fetch json body
  42. parser = reqparse.RequestParser()
  43. parser.add_argument("tenant_id", type=str, required=True, location="json")
  44. parser.add_argument("user_id", type=str, required=True, location="json")
  45. kwargs = parser.parse_args()
  46. user_id = kwargs.get("user_id")
  47. tenant_id = kwargs.get("tenant_id")
  48. if not tenant_id:
  49. raise ValueError("tenant_id is required")
  50. if not user_id:
  51. user_id = "DEFAULT-USER"
  52. del kwargs["tenant_id"]
  53. del kwargs["user_id"]
  54. try:
  55. tenant_model = (
  56. db.session.query(Tenant)
  57. .filter(
  58. Tenant.id == tenant_id,
  59. )
  60. .first()
  61. )
  62. except Exception:
  63. raise ValueError("tenant not found")
  64. if not tenant_model:
  65. raise ValueError("tenant not found")
  66. kwargs["tenant_model"] = tenant_model
  67. kwargs["user_model"] = get_user(tenant_id, user_id)
  68. return view_func(*args, **kwargs)
  69. return decorated_view
  70. if view is None:
  71. return decorator
  72. else:
  73. return decorator(view)
  74. def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
  75. def decorator(view_func):
  76. def decorated_view(*args, **kwargs):
  77. try:
  78. data = request.get_json()
  79. except Exception:
  80. raise ValueError("invalid json")
  81. try:
  82. payload = payload_type(**data)
  83. except Exception as e:
  84. raise ValueError(f"invalid payload: {str(e)}")
  85. kwargs["payload"] = payload
  86. return view_func(*args, **kwargs)
  87. return decorated_view
  88. if view is None:
  89. return decorator
  90. else:
  91. return decorator(view)