billing_service.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
import logging
import os
from typing import Literal, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed

from extensions.ext_database import db
from libs.helper import RateLimiter
from models.account import TenantAccountJoin, TenantAccountRole
  8. class BillingService:
  9. base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
  10. secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
  11. compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
  12. @classmethod
  13. def get_info(cls, tenant_id: str):
  14. params = {"tenant_id": tenant_id}
  15. billing_info = cls._send_request("GET", "/subscription/info", params=params)
  16. return billing_info
  17. @classmethod
  18. def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
  19. params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
  20. return cls._send_request("GET", "/subscription/payment-link", params=params)
  21. @classmethod
  22. def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
  23. params = {
  24. "provider_name": provider_name,
  25. "tenant_id": tenant_id,
  26. "account_id": account_id,
  27. "prefilled_email": prefilled_email,
  28. }
  29. return cls._send_request("GET", "/model-provider/payment-link", params=params)
  30. @classmethod
  31. def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
  32. params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
  33. return cls._send_request("GET", "/invoices", params=params)
  34. @classmethod
  35. @retry(
  36. wait=wait_fixed(2),
  37. stop=stop_before_delay(10),
  38. retry=retry_if_exception_type(httpx.RequestError),
  39. reraise=True,
  40. )
  41. def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
  42. headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
  43. url = f"{cls.base_url}{endpoint}"
  44. response = httpx.request(method, url, json=json, params=params, headers=headers)
  45. if method == "GET" and response.status_code != httpx.codes.OK:
  46. raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
  47. return response.json()
  48. @staticmethod
  49. def is_tenant_owner_or_admin(current_user):
  50. tenant_id = current_user.current_tenant_id
  51. join: Optional[TenantAccountJoin] = (
  52. db.session.query(TenantAccountJoin)
  53. .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
  54. .first()
  55. )
  56. if not join:
  57. raise ValueError("Tenant account join not found")
  58. if not TenantAccountRole.is_privileged_role(join.role):
  59. raise ValueError("Only team owner or team admin can perform this action")
  60. @classmethod
  61. def delete_account(cls, account_id: str):
  62. """Delete account."""
  63. params = {"account_id": account_id}
  64. return cls._send_request("DELETE", "/account/", params=params)
  65. @classmethod
  66. def is_email_in_freeze(cls, email: str) -> bool:
  67. params = {"email": email}
  68. try:
  69. response = cls._send_request("GET", "/account/in-freeze", params=params)
  70. return bool(response.get("data", False))
  71. except Exception:
  72. return False
  73. @classmethod
  74. def update_account_deletion_feedback(cls, email: str, feedback: str):
  75. """Update account deletion feedback."""
  76. json = {"email": email, "feedback": feedback}
  77. return cls._send_request("POST", "/account/delete-feedback", json=json)
  78. @classmethod
  79. def get_compliance_download_link(
  80. cls,
  81. doc_name: str,
  82. account_id: str,
  83. tenant_id: str,
  84. ip: str,
  85. device_info: str,
  86. ):
  87. limiter_key = f"{account_id}:{tenant_id}"
  88. if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
  89. from controllers.console.error import CompilanceRateLimitError
  90. raise CompilanceRateLimitError()
  91. json = {
  92. "doc_name": doc_name,
  93. "account_id": account_id,
  94. "tenant_id": tenant_id,
  95. "ip_address": ip,
  96. "device_info": device_info,
  97. }
  98. res = cls._send_request("POST", "/compliance/download", json=json)
  99. cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
  100. return res