Selaa lähdekoodia

feat: add billing switch. (#1789)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Garfield Dai 1 vuosi sitten
vanhempi
commit
7b37e05dec

+ 4 - 0
api/config.py

@@ -55,6 +55,8 @@ DEFAULTS = {
     'OUTPUT_MODERATION_BUFFER_SIZE': 300,
     'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
     'INVITE_EXPIRY_HOURS': 72,
+    'BILLING_ENABLED': 'False',
+    'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
 }
 
@@ -279,6 +281,8 @@ class Config:
 
         self.ETL_TYPE = get_env('ETL_TYPE')
         self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
+        self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
+        self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
 
 
 class CloudEditionConfig(Config):

+ 1 - 1
api/controllers/console/__init__.py

@@ -6,7 +6,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
 api = ExternalApi(bp)
 
 # Import other controllers
-from . import extension, setup, version, apikey, admin
+from . import extension, setup, version, apikey, admin, feature
 
 # Import app controllers
 from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation

+ 0 - 16
api/controllers/console/billing/billing.py

@@ -1,6 +1,5 @@
 from flask_restful import Resource, reqparse
 from flask_login import current_user
-from flask import current_app
 
 from controllers.console import api
 from controllers.console.setup import setup_required
@@ -10,20 +9,6 @@ from libs.login import login_required
 from services.billing_service import BillingService
 
 
-class BillingInfo(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self):
-
-        edition = current_app.config['EDITION']
-        if edition != 'CLOUD':
-            return {"enabled": False}
-
-        return BillingService.get_info(current_user.current_tenant_id)
-
-
 class Subscription(Resource):
 
     @setup_required
@@ -56,6 +41,5 @@ class Invoices(Resource):
         return BillingService.get_invoices(current_user.email)
 
 
-api.add_resource(BillingInfo, '/billing/info')
 api.add_resource(Subscription, '/billing/subscription')
 api.add_resource(Invoices, '/billing/invoices')

+ 14 - 0
api/controllers/console/feature.py

@@ -0,0 +1,14 @@
+from flask_restful import Resource
+from flask_login import current_user
+
+from . import api
+from services.feature_service import FeatureService
+
+
+class FeatureApi(Resource):
+
+    def get(self):
+        return FeatureService.get_features(current_user.current_tenant_id).dict()
+
+
+api.add_resource(FeatureApi, '/features')

+ 14 - 14
api/controllers/console/wraps.py

@@ -5,7 +5,7 @@ from flask import current_app, abort
 from flask_login import current_user
 
 from controllers.console.workspace.error import AccountNotInitializedError
-from services.billing_service import BillingService
+from services.feature_service import FeatureService
 
 
 def account_initialization_required(view):
@@ -49,23 +49,23 @@ def cloud_edition_billing_resource_check(resource: str,
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
-            if current_app.config['EDITION'] == 'CLOUD':
-                tenant_id = current_user.current_tenant_id
-                billing_info = BillingService.get_info(tenant_id)
-                members = billing_info['members']
-                apps = billing_info['apps']
-                vector_space = billing_info['vector_space']
-                annotation_quota_limit = billing_info['annotation_quota_limit']
-
-                if resource == 'members' and 0 < members['limit'] <= members['size']:
+            features = FeatureService.get_features(current_user.current_tenant_id)
+
+            if features.billing.enabled:
+                members = features.members
+                apps = features.apps
+                vector_space = features.vector_space
+                annotation_quota_limit = features.annotation_quota_limit
+
+                if resource == 'members' and 0 < members.limit <= members.size:
                     abort(403, error_msg)
-                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                elif resource == 'apps' and 0 < apps.limit <= apps.size:
                     abort(403, error_msg)
-                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
                     abort(403, error_msg)
-                elif resource == 'workspace_custom' and not billing_info['can_replace_logo']:
+                elif resource == 'workspace_custom' and not features.can_replace_logo:
                     abort(403, error_msg)
-                elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] < annotation_quota_limit['size']:
+                elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
                     abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)

+ 10 - 11
api/controllers/service_api/wraps.py

@@ -11,8 +11,7 @@ from libs.login import _get_user
 from extensions.ext_database import db
 from models.account import Tenant, TenantAccountJoin, Account
 from models.model import ApiToken, App
-from services.billing_service import BillingService
-
+from services.feature_service import FeatureService
 
 def validate_app_token(view=None):
     def decorator(view):
@@ -46,19 +45,19 @@ def cloud_edition_billing_resource_check(resource: str,
                                          error_msg: str = "You have reached the limit of your subscription."):
     def interceptor(view):
         def decorated(*args, **kwargs):
-            if current_app.config['EDITION'] == 'CLOUD':
-                api_token = validate_and_get_api_token(api_token_type)
-                billing_info = BillingService.get_info(api_token.tenant_id)
+            api_token = validate_and_get_api_token(api_token_type)
+            features = FeatureService.get_features(api_token.tenant_id)
 
-                members = billing_info['members']
-                apps = billing_info['apps']
-                vector_space = billing_info['vector_space']
+            if features.billing.enabled:
+                members = features.members
+                apps = features.apps
+                vector_space = features.vector_space
 
-                if resource == 'members' and 0 < members['limit'] <= members['size']:
+                if resource == 'members' and 0 < members.limit <= members.size:
                     raise Unauthorized(error_msg)
-                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                elif resource == 'apps' and 0 < apps.limit <= apps.size:
                     raise Unauthorized(error_msg)
-                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
                     raise Unauthorized(error_msg)
                 else:
                     return view(*args, **kwargs)

+ 2 - 7
api/controllers/web/site.py

@@ -9,7 +9,7 @@ from controllers.web import api
 from controllers.web.wraps import WebApiResource
 from extensions.ext_database import db
 from models.model import Site
-from services.billing_service import BillingService
+from services.feature_service import FeatureService
 
 
 class AppSiteApi(WebApiResource):
@@ -56,12 +56,7 @@ class AppSiteApi(WebApiResource):
         if not site:
             raise Forbidden()
 
-        edition = os.environ.get('EDITION')
-        can_replace_logo = False
-
-        if edition == 'CLOUD':
-            info = BillingService.get_info(app_model.tenant_id)
-            can_replace_logo = info['can_replace_logo']
+        can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
 
         return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
 

+ 71 - 0
api/services/feature_service.py

@@ -0,0 +1,71 @@
+from pydantic import BaseModel
+from flask import current_app
+
+from services.billing_service import BillingService
+
+
+class SubscriptionModel(BaseModel):
+    plan: str = 'sandbox'
+    interval: str = ''
+
+
+class BillingModel(BaseModel):
+    enabled: bool = False
+    subscription: SubscriptionModel = SubscriptionModel()
+
+
+class LimitationModel(BaseModel):
+    size: int = 0
+    limit: int = 0
+
+
+class FeatureModel(BaseModel):
+    billing: BillingModel = BillingModel()
+    members: LimitationModel = LimitationModel(size=0, limit=1)
+    apps: LimitationModel = LimitationModel(size=0, limit=10)
+    vector_space: LimitationModel = LimitationModel(size=0, limit=5)
+    annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
+    docs_processing: str = 'standard'
+    can_replace_logo: bool = False
+
+
+class FeatureService:
+
+    @classmethod
+    def get_features(cls, tenant_id: str) -> FeatureModel:
+        features = FeatureModel()
+
+        cls._fulfill_params_from_env(features)
+
+        if current_app.config['BILLING_ENABLED']:
+            cls._fulfill_params_from_billing_api(features, tenant_id)
+
+        return features
+
+    @classmethod
+    def _fulfill_params_from_env(cls, features: FeatureModel):
+        features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
+
+    @classmethod
+    def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
+        billing_info = BillingService.get_info(tenant_id)
+
+        features.billing.enabled = billing_info['enabled']
+        features.billing.subscription.plan = billing_info['subscription']['plan']
+        features.billing.subscription.interval = billing_info['subscription']['interval']
+
+        features.members.size = billing_info['members']['size']
+        features.members.limit = billing_info['members']['limit']
+
+        features.apps.size = billing_info['apps']['size']
+        features.apps.limit = billing_info['apps']['limit']
+
+        features.vector_space.size = billing_info['vector_space']['size']
+        features.vector_space.limit = billing_info['vector_space']['limit']
+
+        features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
+        features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
+
+        features.docs_processing = billing_info['docs_processing']
+        features.can_replace_logo = billing_info['can_replace_logo']
+

+ 4 - 6
api/services/workspace_service.py

@@ -4,7 +4,7 @@ from extensions.ext_database import db
 from models.account import Tenant, TenantAccountJoin, TenantAccountJoinRole
 from models.provider import Provider
 
-from services.billing_service import BillingService
+from services.feature_service import FeatureService
 from services.account_service import TenantService
 
 
@@ -32,12 +32,10 @@ class WorkspaceService:
         ).first()
         tenant_info['role'] = tenant_account_join.role
 
-        edition = current_app.config['EDITION']
-        if edition == 'CLOUD':
-            billing_info = BillingService.get_info(tenant_info['id'])
+        can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
 
-            if billing_info['can_replace_logo'] and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
-                tenant_info['custom_config'] = tenant.custom_config_dict
+        if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
+            tenant_info['custom_config'] = tenant.custom_config_dict
 
         # Get providers
         providers = db.session.query(Provider).filter(

+ 6 - 3
web/app/components/billing/type.ts

@@ -32,9 +32,11 @@ export enum DocumentProcessingPriority {
 }
 
 export type CurrentPlanInfoBackend = {
-  enabled: boolean
-  subscription: {
-    plan: Plan
+  billing: {
+    enabled: boolean
+    subscription: {
+      plan: Plan
+    }
   }
   members: {
     size: number
@@ -53,6 +55,7 @@ export type CurrentPlanInfoBackend = {
     limit: number // total. 0 means unlimited
   }
   docs_processing: DocumentProcessingPriority
+  can_replace_logo: boolean
 }
 
 export type SubscriptionItem = {

+ 1 - 1
web/app/components/billing/utils/index.ts

@@ -10,7 +10,7 @@ const parseLimit = (limit: number) => {
 
 export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => {
   return {
-    type: data.subscription.plan,
+    type: data.billing.subscription.plan,
     usage: {
       vectorSpace: data.vector_space.size,
       buildApps: data.apps?.size || 0,

+ 8 - 4
web/app/components/custom/custom-page/index.tsx

@@ -10,12 +10,16 @@ import { contactSalesUrl } from '@/app/components/billing/config'
 
 const CustomPage = () => {
   const { t } = useTranslation()
-  const { plan } = useProviderContext()
+  const { plan, enableBilling } = useProviderContext()
+
+  const showBillingTip = enableBilling && plan.type === Plan.sandbox
+  const showCustomAppHeaderBrand = enableBilling && plan.type === Plan.sandbox
+  const showContact = enableBilling && (plan.type === Plan.professional || plan.type === Plan.team)
 
   return (
     <div className='flex flex-col'>
       {
-        plan.type === Plan.sandbox && (
+        showBillingTip && (
           <GridMask canvasClassName='!rounded-xl'>
             <div className='flex justify-between mb-1 px-6 py-5 h-[88px] shadow-md rounded-xl border-[0.5px] border-gray-200'>
               <div className={`${s.textGradient} leading-[24px] text-base font-semibold`}>
@@ -29,7 +33,7 @@ const CustomPage = () => {
       }
       <CustomWebAppBrand />
       {
-        plan.type === Plan.sandbox && (
+        showCustomAppHeaderBrand && (
           <>
             <div className='my-2 h-[0.5px] bg-gray-100'></div>
             <CustomAppHeaderBrand />
@@ -37,7 +41,7 @@ const CustomPage = () => {
         )
       }
       {
-        (plan.type === Plan.professional || plan.type === Plan.team) && (
+        showContact && (
           <div className='absolute bottom-0 h-[50px] leading-[50px] text-xs text-gray-500'>
             {t('custom.customize.prefix')}
             <a className='text-[#155EEF]' href={contactSalesUrl} target='_blank'>{t('custom.customize.contactUs')}</a>

+ 9 - 8
web/app/components/custom/custom-web-app-brand/index.tsx

@@ -24,7 +24,7 @@ const ALLOW_FILE_EXTENSIONS = ['svg', 'png']
 const CustomWebAppBrand = () => {
   const { t } = useTranslation()
   const { notify } = useToastContext()
-  const { plan } = useProviderContext()
+  const { plan, enableBilling } = useProviderContext()
   const {
     currentWorkspace,
     mutateCurrentWorkspace,
@@ -32,10 +32,11 @@ const CustomWebAppBrand = () => {
   } = useAppContext()
   const [fileId, setFileId] = useState('')
   const [uploadProgress, setUploadProgress] = useState(0)
-  const isSandbox = plan.type === Plan.sandbox
+  const isSandbox = enableBilling && plan.type === Plan.sandbox
   const uploading = uploadProgress > 0 && uploadProgress < 100
   const webappLogo = currentWorkspace.custom_config?.replace_webapp_logo || ''
   const webappBrandRemoved = currentWorkspace.custom_config?.remove_webapp_brand
+  const uploadDisabled = isSandbox || webappBrandRemoved || !isCurrentWorkspaceManager
 
   const handleChange = (e: ChangeEvent<HTMLInputElement>) => {
     const file = e.target.files?.[0]
@@ -153,9 +154,9 @@ const CustomWebAppBrand = () => {
               <Button
                 className={`
                   relative mr-2 !h-8 !px-3 bg-white !text-[13px] 
-                  ${isSandbox ? 'opacity-40' : ''}
+                  ${uploadDisabled ? 'opacity-40' : ''}
                 `}
-                disabled={isSandbox || webappBrandRemoved || !isCurrentWorkspaceManager}
+                disabled={uploadDisabled}
               >
                 <ImagePlus className='mr-2 w-4 h-4' />
                 {
@@ -166,13 +167,13 @@ const CustomWebAppBrand = () => {
                 <input
                   className={`
                     absolute block inset-0 opacity-0 text-[0] w-full
-                    ${(isSandbox || webappBrandRemoved) ? 'cursor-not-allowed' : 'cursor-pointer'}
+                    ${uploadDisabled ? 'cursor-not-allowed' : 'cursor-pointer'}
                   `}
                   onClick={e => (e.target as HTMLInputElement).value = ''}
                   type='file'
                   accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')}
                   onChange={handleChange}
-                  disabled={isSandbox || webappBrandRemoved || !isCurrentWorkspaceManager}
+                  disabled={uploadDisabled}
                 />
               </Button>
             )
@@ -213,9 +214,9 @@ const CustomWebAppBrand = () => {
           <Button
             className={`
               !h-8 !px-3 bg-white !text-[13px] 
-              ${isSandbox ? 'opacity-40' : ''}
+              ${(uploadDisabled || (!webappLogo && !webappBrandRemoved)) ? 'opacity-40' : ''}
             `}
-            disabled={isSandbox || (!webappLogo && !webappBrandRemoved) || webappBrandRemoved || !isCurrentWorkspaceManager}
+            disabled={uploadDisabled || (!webappLogo && !webappBrandRemoved)}
             onClick={handleRestore}
           >
             {t('custom.restore')}

+ 2 - 3
web/app/components/header/account-setting/index.tsx

@@ -31,7 +31,6 @@ import { Colors } from '@/app/components/base/icons/src/vender/line/editor'
 import { Colors as ColorsSolid } from '@/app/components/base/icons/src/vender/solid/editor'
 import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
 import { useProviderContext } from '@/context/provider-context'
-import { IS_CE_EDITION } from '@/config'
 
 const iconClassName = `
   w-4 h-4 ml-3 mr-2
@@ -59,7 +58,7 @@ export default function AccountSetting({
 }: IAccountSettingProps) {
   const [activeMenu, setActiveMenu] = useState(activeTab)
   const { t } = useTranslation()
-  const { enableBilling } = useProviderContext()
+  const { enableBilling, enableReplaceWebAppLogo } = useProviderContext()
 
   const workplaceGroupItems = (() => {
     return [
@@ -101,7 +100,7 @@ export default function AccountSetting({
         activeIcon: <Webhooks className={iconClassName} />,
       },
       {
-        key: IS_CE_EDITION ? false : 'custom',
+        key: (enableReplaceWebAppLogo || enableBilling) ? 'custom' : false,
         name: t('custom.custom'),
         icon: <Colors className={iconClassName} />,
         activeIcon: <ColorsSolid className={iconClassName} />,

+ 6 - 1
web/context/provider-context.tsx

@@ -37,6 +37,7 @@ const ProviderContext = createContext<{
   }
   isFetchedPlan: boolean
   enableBilling: boolean
+  enableReplaceWebAppLogo: boolean
 }>({
       textGenerationModelList: [],
       embeddingsModelList: [],
@@ -72,6 +73,7 @@ const ProviderContext = createContext<{
       },
       isFetchedPlan: false,
       enableBilling: false,
+      enableReplaceWebAppLogo: false,
     })
 
 export const useProviderContext = () => useContext(ProviderContext)
@@ -119,11 +121,13 @@ export const ProviderContextProvider = ({
   const [plan, setPlan] = useState(defaultPlan)
   const [isFetchedPlan, setIsFetchedPlan] = useState(false)
   const [enableBilling, setEnableBilling] = useState(true)
+  const [enableReplaceWebAppLogo, setEnableReplaceWebAppLogo] = useState(false)
   useEffect(() => {
     (async () => {
       const data = await fetchCurrentPlanInfo()
-      const enabled = data.enabled
+      const enabled = data.billing.enabled
       setEnableBilling(enabled)
+      setEnableReplaceWebAppLogo(data.can_replace_logo)
       if (enabled) {
         setPlan(parseCurrentPlan(data))
         // setPlan(parseCurrentPlan({
@@ -160,6 +164,7 @@ export const ProviderContextProvider = ({
       plan,
       isFetchedPlan,
       enableBilling,
+      enableReplaceWebAppLogo,
     }}>
       {children}
     </ProviderContext.Provider>

+ 1 - 1
web/service/billing.ts

@@ -2,7 +2,7 @@ import { get } from './base'
 import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type'
 
 export const fetchCurrentPlanInfo = () => {
-  return get<Promise<CurrentPlanInfoBackend>>('/billing/info')
+  return get<Promise<CurrentPlanInfoBackend>>('/features')
 }
 
 export const fetchSubscriptionUrls = (plan: string, interval: string) => {