Prechádzať zdrojové kódy

add segment function billing check for SAAS env (#3082)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok pred
rodič
commit
e12a0c154c

+ 7 - 1
api/controllers/console/datasets/datasets_segments.py

@@ -12,7 +12,11 @@ from controllers.console import api
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_knowledge_limit_check,
+    cloud_edition_billing_resource_check,
+)
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -207,6 +211,7 @@ class DatasetDocumentSegmentAddApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_knowledge_limit_check('add_segment')
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -357,6 +362,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_knowledge_limit_check('add_segment')
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)

+ 23 - 2
api/controllers/console/wraps.py

@@ -51,14 +51,12 @@ def cloud_edition_billing_resource_check(resource: str,
         @wraps(view)
         def decorated(*args, **kwargs):
             features = FeatureService.get_features(current_user.current_tenant_id)
-
             if features.billing.enabled:
                 members = features.members
                 apps = features.apps
                 vector_space = features.vector_space
                 documents_upload_quota = features.documents_upload_quota
                 annotation_quota_limit = features.annotation_quota_limit
-
                 if resource == 'members' and 0 < members.limit <= members.size:
                     abort(403, error_msg)
                 elif resource == 'apps' and 0 < apps.limit <= apps.size:
@@ -80,7 +78,29 @@ def cloud_edition_billing_resource_check(resource: str,
                     return view(*args, **kwargs)
 
             return view(*args, **kwargs)
+
         return decorated
+
+    return interceptor
+
+
+def cloud_edition_billing_knowledge_limit_check(resource: str,
+                                                error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            features = FeatureService.get_features(current_user.current_tenant_id)
+            if features.billing.enabled:
+                if resource == 'add_segment':
+                    if features.billing.subscription.plan == 'sandbox':
+                        abort(403, error_msg)
+                else:
+                    return view(*args, **kwargs)
+
+            return view(*args, **kwargs)
+
+        return decorated
+
     return interceptor
 
 
@@ -99,4 +119,5 @@ def cloud_utm_record(view):
         except Exception as e:
             pass
         return view(*args, **kwargs)
+
     return decorated

+ 6 - 1
api/controllers/service_api/dataset/segment.py

@@ -4,7 +4,11 @@ from werkzeug.exceptions import NotFound
 
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
-from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
+from controllers.service_api.wraps import (
+    DatasetApiResource,
+    cloud_edition_billing_knowledge_limit_check,
+    cloud_edition_billing_resource_check,
+)
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -18,6 +22,7 @@ class SegmentApi(DatasetApiResource):
     """Resource for segments."""
 
     @cloud_edition_billing_resource_check('vector_space', 'dataset')
+    @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset')
     def post(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
         # check dataset

+ 26 - 5
api/controllers/service_api/wraps.py

@@ -8,7 +8,7 @@ from flask import current_app, request
 from flask_login import user_logged_in
 from flask_restful import Resource
 from pydantic import BaseModel
-from werkzeug.exceptions import NotFound, Unauthorized
+from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 
 from extensions.ext_database import db
 from libs.login import _get_user
@@ -92,13 +92,13 @@ def cloud_edition_billing_resource_check(resource: str,
                 documents_upload_quota = features.documents_upload_quota
 
                 if resource == 'members' and 0 < members.limit <= members.size:
-                    raise Unauthorized(error_msg)
+                    raise Forbidden(error_msg)
                 elif resource == 'apps' and 0 < apps.limit <= apps.size:
-                    raise Unauthorized(error_msg)
+                    raise Forbidden(error_msg)
                 elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
-                    raise Unauthorized(error_msg)
+                    raise Forbidden(error_msg)
                 elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
-                    raise Unauthorized(error_msg)
+                    raise Forbidden(error_msg)
                 else:
                     return view(*args, **kwargs)
 
@@ -107,6 +107,27 @@ def cloud_edition_billing_resource_check(resource: str,
     return interceptor
 
 
+def cloud_edition_billing_knowledge_limit_check(resource: str,
+                                                api_token_type: str,
+                                                error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            api_token = validate_and_get_api_token(api_token_type)
+            features = FeatureService.get_features(api_token.tenant_id)
+            if features.billing.enabled:
+                if resource == 'add_segment':
+                    if features.billing.subscription.plan == 'sandbox':
+                        raise Forbidden(error_msg)
+                else:
+                    return view(*args, **kwargs)
+
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
+
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)