Pārlūkot izejas kodu

Feat/new saas billing (#12591)

Jyong 3 mēneši atpakaļ
vecāks
revīzija
d8f57bf899

+ 9 - 1
api/controllers/console/datasets/datasets.py

@@ -10,7 +10,12 @@ from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
-from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    enterprise_license_required,
+    setup_required,
+)
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
@@ -93,6 +98,7 @@ class DatasetListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument(
@@ -207,6 +213,7 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -310,6 +317,7 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id_str = str(dataset_id)
 

+ 10 - 0
api/controllers/console/datasets/datasets_document.py

@@ -27,6 +27,7 @@ from controllers.console.datasets.error import (
 )
 from controllers.console.wraps import (
     account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -230,6 +231,7 @@ class DatasetDocumentListApi(Resource):
     @account_initialization_required
     @marshal_with(documents_and_batch_fields)
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
 
@@ -285,6 +287,7 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -308,6 +311,7 @@ class DatasetInitApi(Resource):
     @account_initialization_required
     @marshal_with(dataset_and_document_fields)
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
@@ -680,6 +684,7 @@ class DocumentProcessingApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -716,6 +721,7 @@ class DocumentDeleteApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -784,6 +790,7 @@ class DocumentStatusApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -879,6 +886,7 @@ class DocumentPauseApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -911,6 +919,7 @@ class DocumentRecoverApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -940,6 +949,7 @@ class DocumentRetryApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         """retry document."""
 

+ 11 - 0
api/controllers/console/datasets/datasets_segments.py

@@ -19,6 +19,7 @@ from controllers.console.datasets.error import (
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_knowledge_limit_check,
+    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -192,6 +195,7 @@ class DatasetDocumentSegmentAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -242,6 +246,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -302,6 +307,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -339,6 +345,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -405,6 +412,7 @@ class ChildChunkAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -503,6 +511,7 @@ class ChildChunkAddApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -546,6 +555,7 @@ class ChildChunkUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -590,6 +600,7 @@ class ChildChunkUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)

+ 6 - 1
api/controllers/console/datasets/hit_testing.py

@@ -2,7 +2,11 @@ from flask_restful import Resource  # type: ignore
 
 from controllers.console import api
 from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    setup_required,
+)
 from libs.login import login_required
 
 
@@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id_str = str(dataset_id)
 

+ 32 - 1
api/controllers/console/wraps.py

@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from functools import wraps
 
 from flask import abort, request
@@ -7,6 +8,7 @@ from flask_login import current_user  # type: ignore
 
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
+from extensions.ext_redis import redis_client
 from models.model import DifySetup
 from services.feature_service import FeatureService, LicenseStatus
 from services.operation_service import OperationService
@@ -66,7 +68,9 @@ def cloud_edition_billing_resource_check(resource: str):
                 elif resource == "apps" and 0 < apps.limit <= apps.size:
                     abort(403, "The number of apps has reached the limit of your subscription.")
                 elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
-                    abort(403, "The capacity of the vector space has reached the limit of your subscription.")
+                    abort(
+                        403, "The capacity of the knowledge storage space has reached the limit of your subscription."
+                    )
                 elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
                     # The api of file upload is used in the multiple places,
                     # so we need to check the source of the request from datasets
@@ -111,6 +115,33 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
     return interceptor
 
 
+def cloud_edition_billing_rate_limit_check(resource: str):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            if resource == "knowledge":
+                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
+                if knowledge_rate_limit.enabled:
+                    current_time = int(time.time() * 1000)
+                    key = f"rate_limit_{current_user.current_tenant_id}"
+
+                    redis_client.zadd(key, {current_time: current_time})
+
+                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
+
+                    request_count = redis_client.zcard(key)
+
+                    if request_count > knowledge_rate_limit.limit:
+                        abort(
+                            403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
+                        )
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
+
+
 def cloud_utm_record(view):
     @wraps(view)
     def decorated(*args, **kwargs):

+ 31 - 0
api/controllers/service_api/wraps.py

@@ -1,3 +1,4 @@
+import time
 from collections.abc import Callable
 from datetime import UTC, datetime, timedelta
 from enum import Enum
@@ -13,6 +14,7 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, Unauthorized
 
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.model import ApiToken, App, EndUser
@@ -139,6 +141,35 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
     return interceptor
 
 
+def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            api_token = validate_and_get_api_token(api_token_type)
+
+            if resource == "knowledge":
+                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
+                if knowledge_rate_limit.enabled:
+                    current_time = int(time.time() * 1000)
+                    key = f"rate_limit_{api_token.tenant_id}"
+
+                    redis_client.zadd(key, {current_time: current_time})
+
+                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
+
+                    request_count = redis_client.zcard(key)
+
+                    if request_count > knowledge_rate_limit.limit:
+                        raise Forbidden(
+                            "Sorry, you have reached the knowledge base request rate limit of your subscription."
+                        )
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
+
+
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)

+ 20 - 0
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,4 +1,5 @@
 import logging
+import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
@@ -19,8 +20,10 @@ from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document
 from models.workflow import WorkflowNodeExecutionStatus
+from services.feature_service import FeatureService
 
 from .entities import KnowledgeRetrievalNodeData
 from .exc import (
@@ -61,6 +64,23 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
             )
+        # check rate limit
+        if self.tenant_id:
+            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
+            if knowledge_rate_limit.enabled:
+                current_time = int(time.time() * 1000)
+                key = f"rate_limit_{self.tenant_id}"
+                redis_client.zadd(key, {current_time: current_time})
+                redis_client.zremrangebyscore(key, 0, current_time - 60000)
+                request_count = redis_client.zcard(key)
+                if request_count > knowledge_rate_limit.limit:
+                    return NodeRunResult(
+                        status=WorkflowNodeExecutionStatus.FAILED,
+                        inputs=variables,
+                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
+                        error_type="RateLimitExceeded",
+                    )
+
         # retrieve knowledge
         try:
             results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)

+ 8 - 0
api/services/billing_service.py

@@ -19,6 +19,14 @@ class BillingService:
         billing_info = cls._send_request("GET", "/subscription/info", params=params)
         return billing_info
 
+    @classmethod
+    def get_knowledge_rate_limit(cls, tenant_id: str):
+        params = {"tenant_id": tenant_id}
+
+        knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
+
+        return knowledge_rate_limit.get("limit", 10)
+
     @classmethod
     def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
         params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}

+ 17 - 0
api/services/feature_service.py

@@ -41,6 +41,7 @@ class FeatureModel(BaseModel):
     members: LimitationModel = LimitationModel(size=0, limit=1)
     apps: LimitationModel = LimitationModel(size=0, limit=10)
     vector_space: LimitationModel = LimitationModel(size=0, limit=5)
+    knowledge_rate_limit: int = 10
     annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
     documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
     docs_processing: str = "standard"
@@ -52,6 +53,11 @@ class FeatureModel(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
+class KnowledgeRateLimitModel(BaseModel):
+    enabled: bool = False
+    limit: int = 10
+
+
 class SystemFeatureModel(BaseModel):
     sso_enforced_for_signin: bool = False
     sso_enforced_for_signin_protocol: str = ""
@@ -79,6 +85,14 @@ class FeatureService:
 
         return features
 
+    @classmethod
+    def get_knowledge_rate_limit(cls, tenant_id: str):
+        knowledge_rate_limit = KnowledgeRateLimitModel()
+        if dify_config.BILLING_ENABLED and tenant_id:
+            knowledge_rate_limit.enabled = True
+            knowledge_rate_limit.limit = BillingService.get_knowledge_rate_limit(tenant_id)
+        return knowledge_rate_limit
+
     @classmethod
     def get_system_features(cls) -> SystemFeatureModel:
         system_features = SystemFeatureModel()
@@ -144,6 +158,9 @@ class FeatureService:
         if "model_load_balancing_enabled" in billing_info:
             features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
 
+        if "knowledge_rate_limit" in billing_info:
+            features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
+
     @classmethod
     def _fulfill_params_from_enterprise(cls, features):
         enterprise_info = EnterpriseService.get_info()