Ver código fonte

feat: billing enhancement 20231204 (#1691)

Co-authored-by: jyong <jyong@dify.ai>
Garfield Dai 1 ano atrás
pai
commit
7b8a10f3ea

+ 0 - 2
api/config.py

@@ -54,7 +54,6 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
     'HOSTED_MODERATION_ENABLED': 'False',
     'HOSTED_MODERATION_PROVIDERS': '',
-    'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
@@ -240,7 +239,6 @@ class Config:
         self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
 
         # Dataset Configurations.
-        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
         self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
 
         # File upload Configurations.

+ 2 - 1
api/controllers/console/app/app.py

@@ -12,7 +12,7 @@ from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.model_provider_factory import ModelProviderFactory
@@ -57,6 +57,7 @@ class AppListApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(app_detail_fields)
+    @cloud_edition_billing_resource_check('apps')
     def post(self):
         """Create app"""
         parser = reqparse.RequestParser()

+ 4 - 25
api/controllers/console/datasets/datasets_document.py

@@ -16,7 +16,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \
     InvalidMetadataError, ArchivedDocumentImmutableError
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.indexing_runner import IndexingRunner
 from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
@@ -194,6 +194,7 @@ class DatasetDocumentListApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(documents_and_batch_fields)
+    @cloud_edition_billing_resource_check('vector_space')
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
 
@@ -252,6 +253,7 @@ class DatasetInitApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(dataset_and_document_fields)
+    @cloud_edition_billing_resource_check('vector_space')
     def post(self):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
@@ -693,6 +695,7 @@ class DocumentStatusApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('vector_space')
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -770,14 +773,6 @@ class DocumentStatusApi(DocumentResource):
             if not document.archived:
                 raise InvalidActionError('Document is not archived.')
 
-            # check document limit
-            if current_app.config['EDITION'] == 'CLOUD':
-                documents_count = DocumentService.get_tenant_documents_count()
-                total_count = documents_count + 1
-                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-                if total_count > tenant_document_count:
-                    raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
-
             document.archived = False
             document.archived_at = None
             document.archived_by = None
@@ -856,21 +851,6 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204
 
 
-class DocumentLimitApi(DocumentResource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self):
-        """get document limit"""
-        documents_count = DocumentService.get_tenant_documents_count()
-        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-
-        return {
-            'documents_count': documents_count,
-            'documents_limit': tenant_document_count
-                }, 200
-
-
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -896,4 +876,3 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
-api.add_resource(DocumentLimitApi, '/datasets/limit')

+ 5 - 1
api/controllers/console/datasets/datasets_segments.py

@@ -11,7 +11,7 @@ from controllers.console import api
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from libs.login import login_required
@@ -114,6 +114,7 @@ class DatasetDocumentSegmentApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('vector_space')
     def patch(self, dataset_id, segment_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -200,6 +201,7 @@ class DatasetDocumentSegmentAddApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('vector_space')
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -250,6 +252,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('vector_space')
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -344,6 +347,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('vector_space')
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)

+ 2 - 0
api/controllers/console/explore/installed_app.py

@@ -14,6 +14,7 @@ from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
 from models.model import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
+from controllers.console.wraps import cloud_edition_billing_resource_check
 
 
 class InstalledAppsListApi(Resource):
@@ -47,6 +48,7 @@ class InstalledAppsListApi(Resource):
 
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('apps')
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')

+ 2 - 1
api/controllers/console/workspace/members.py

@@ -7,7 +7,7 @@ from flask_restful import Resource, reqparse, marshal_with, abort, fields, marsh
 import services
 from controllers.console import api
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.account import Account, TenantAccountJoin
@@ -47,6 +47,7 @@ class MemberInviteEmailApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_resource_check('members')
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument('emails', type=str, required=True, location='json', action='append')

+ 28 - 0
api/controllers/console/wraps.py

@@ -5,6 +5,7 @@ from flask import current_app, abort
 from flask_login import current_user
 
 from controllers.console.workspace.error import AccountNotInitializedError
+from services.billing_service import BillingService
 
 
 def account_initialization_required(view):
@@ -41,3 +42,30 @@ def only_edition_self_hosted(view):
         return view(*args, **kwargs)
 
     return decorated
+
+
+def cloud_edition_billing_resource_check(resource: str,
+                                         error_msg: str = "You have reached the limit of your subscription."):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            if current_app.config['EDITION'] == 'CLOUD':
+                tenant_id = current_user.current_tenant_id
+                billing_info = BillingService.get_info(tenant_id)
+                members = billing_info['members']
+                apps = billing_info['apps']
+                vector_space = billing_info['vector_space']
+
+                if resource == 'members' and 0 < members['limit'] <= members['size']:
+                    abort(403, error_msg)
+                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                    abort(403, error_msg)
+                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                    abort(403, error_msg)
+                else:
+                    return view(*args, **kwargs)
+
+            return view(*args, **kwargs)
+        return decorated
+    return interceptor
+

+ 5 - 1
api/controllers/service_api/dataset/document.py

@@ -11,7 +11,7 @@ from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
     NoFileUploadedError, TooManyFilesError
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from libs.login import current_user
 from core.model_providers.error import ProviderTokenNotInitError
 from extensions.ext_database import db
@@ -24,6 +24,7 @@ from services.file_service import FileService
 class DocumentAddByTextApi(DatasetApiResource):
     """Resource for documents."""
 
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id):
         """Create document by text."""
         parser = reqparse.RequestParser()
@@ -88,6 +89,7 @@ class DocumentAddByTextApi(DatasetApiResource):
 class DocumentUpdateByTextApi(DatasetApiResource):
     """Resource for update documents."""
 
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by text."""
         parser = reqparse.RequestParser()
@@ -147,6 +149,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
 
 class DocumentAddByFileApi(DatasetApiResource):
     """Resource for documents."""
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id):
         """Create document by upload file."""
         args = {}
@@ -212,6 +215,7 @@ class DocumentAddByFileApi(DatasetApiResource):
 class DocumentUpdateByFileApi(DatasetApiResource):
     """Resource for update documents."""
 
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id, document_id):
         """Update document by upload file."""
         args = {}

+ 4 - 1
api/controllers/service_api/dataset/segment.py

@@ -3,7 +3,7 @@ from flask_restful import reqparse, marshal
 from werkzeug.exceptions import NotFound
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
@@ -14,6 +14,8 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer
 
 class SegmentApi(DatasetApiResource):
     """Resource for segments."""
+
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
         # check dataset
@@ -144,6 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
         SegmentService.delete_segment(segment, document, dataset)
         return {'result': 'success'}, 200
 
+    @cloud_edition_billing_resource_check('vector_space', 'dataset')
     def post(self, tenant_id, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)

+ 28 - 0
api/controllers/service_api/wraps.py

@@ -11,6 +11,7 @@ from libs.login import _get_user
 from extensions.ext_database import db
 from models.account import Tenant, TenantAccountJoin, Account
 from models.model import ApiToken, App
+from services.billing_service import BillingService
 
 
 def validate_app_token(view=None):
@@ -40,6 +41,33 @@ def validate_app_token(view=None):
     return decorator
 
 
+def cloud_edition_billing_resource_check(resource: str,
+                                         api_token_type: str,
+                                         error_msg: str = "You have reached the limit of your subscription."):
+    def interceptor(view):
+        def decorated(*args, **kwargs):
+            if current_app.config['EDITION'] == 'CLOUD':
+                api_token = validate_and_get_api_token(api_token_type)
+                billing_info = BillingService.get_info(api_token.tenant_id)
+
+                members = billing_info['members']
+                apps = billing_info['apps']
+                vector_space = billing_info['vector_space']
+
+                if resource == 'members' and 0 < members['limit'] <= members['size']:
+                    raise Unauthorized(error_msg)
+                elif resource == 'apps' and 0 < apps['limit'] <= apps['size']:
+                    raise Unauthorized(error_msg)
+                elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']:
+                    raise Unauthorized(error_msg)
+                else:
+                    return view(*args, **kwargs)
+
+            return view(*args, **kwargs)
+        return decorated
+    return interceptor
+
+
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)

+ 5 - 7
api/services/billing_service.py

@@ -1,8 +1,6 @@
 import os
 import requests
 
-from services.dataset_service import DatasetService
-
 
 class BillingService:
     base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
@@ -14,14 +12,14 @@ class BillingService:
 
         billing_info = cls._send_request('GET', '/info', params=params)
 
-        vector_size = DatasetService.get_tenant_datasets_usage(tenant_id)
-        # Convert bytes to MB
-        billing_info['vector_space']['size'] = int(vector_size / 1024 / 1024)
-
         return billing_info
 
     @classmethod
-    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
+    def get_subscription(cls, plan: str,
+                         interval: str,
+                         prefilled_email: str = '',
+                         user_name: str = '',
+                         tenant_id: str = ''):
         params = {
             'plan': plan,
             'interval': interval,

+ 1 - 42
api/services/dataset_service.py

@@ -227,36 +227,6 @@ class DatasetService:
         return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
             .order_by(db.desc(AppDatasetJoin.created_at)).all()
 
-    @staticmethod
-    def get_tenant_datasets_usage(tenant_id):
-        # get the high_quality datasets
-        dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
-                                                          Dataset.tenant_id == tenant_id).all()
-        if not dataset_ids:
-            return 0
-        dataset_ids = [result[0] for result in dataset_ids]
-        document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
-                                                            Document.tenant_id == tenant_id,
-                                                            Document.completed_at.isnot(None),
-                                                            Document.enabled == True,
-                                                            Document.archived == False
-                                                            ).all()
-        if not document_ids:
-            return 0
-        document_ids = [result[0] for result in document_ids]
-        document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
-                                                                     DocumentSegment.tenant_id == tenant_id,
-                                                                     DocumentSegment.completed_at.isnot(None),
-                                                                     DocumentSegment.enabled == True,
-                                                                     ).all()
-        if not document_segments:
-            return 0
-
-        total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
-        total_vector_size = 1536 * 4 * len(document_segments)
-
-        return total_words_size + total_vector_size
-
 
 class DocumentService:
     DEFAULT_RULES = {
@@ -480,11 +450,6 @@ class DocumentService:
                     notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
                     for notion_info in notion_info_list:
                         count = count + len(notion_info['pages'])
-                documents_count = DocumentService.get_tenant_documents_count()
-                total_count = documents_count + count
-                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-                if total_count > tenant_document_count:
-                    raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
             dataset.data_source_type = document_data["data_source"]["type"]
@@ -770,13 +735,7 @@ class DocumentService:
             notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
             for notion_info in notion_info_list:
                 count = count + len(notion_info['pages'])
-        # check document limit
-        if current_app.config['EDITION'] == 'CLOUD':
-            documents_count = DocumentService.get_tenant_documents_count()
-            total_count = documents_count + count
-            tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if total_count > tenant_document_count:
-                raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
+
         embedding_model = None
         dataset_collection_binding_id = None
         retrieval_model = None