瀏覽代碼

feat: app rate limit (#5844)

Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
Co-authored-by: takatost <takatost@gmail.com>
liuzhenghua 9 月之前
父節點
當前提交
9622fbb62f

+ 1 - 1
api/.env.example

@@ -247,4 +247,4 @@ WORKFLOW_CALL_MAX_DEPTH=5
 
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
-
+APP_MAX_ACTIVE_REQUESTS=0

+ 4 - 0
api/configs/feature/__init__.py

@@ -31,6 +31,10 @@ class AppExecutionConfig(BaseSettings):
         description='execution timeout in seconds for app execution',
         default=1200,
     )
+    APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
+        description='max active request per app, 0 means unlimited',
+        default=0,
+    )
 
 
 class CodeExecutionSandboxConfig(BaseSettings):

+ 1 - 0
api/controllers/console/app/app.py

@@ -134,6 +134,7 @@ class AppApi(Resource):
         parser.add_argument('description', type=str, location='json')
         parser.add_argument('icon', type=str, location='json')
         parser.add_argument('icon_background', type=str, location='json')
+        parser.add_argument('max_active_requests', type=int, location='json')
         args = parser.parse_args()
 
         app_service = AppService()

+ 8 - 3
api/controllers/console/app/completion.py

@@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
@@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

+ 2 - 1
api/controllers/console/app/workflow.py

@@ -13,6 +13,7 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import AppInvokeQuotaExceededError
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
@@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
             )
 
             return helper.compact_generate_response(response)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

+ 8 - 3
api/controllers/service_api/app/completion.py

@@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
@@ -69,7 +74,7 @@ class CompletionApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")
@@ -132,7 +137,7 @@ class ChatApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

+ 7 - 2
api/controllers/service_api/app/workflow.py

@@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
 from core.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser
@@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
             raise ProviderModelCurrentlyNotSupportError()
         except InvokeError as e:
             raise CompletionRequestError(e.description)
-        except ValueError as e:
+        except (ValueError, AppInvokeQuotaExceededError) as e:
             raise e
         except Exception as e:
             logging.exception("internal server error.")

+ 1 - 0
api/core/app/features/rate_limiting/__init__.py

@@ -0,0 +1 @@
+from .rate_limit import RateLimit

+ 120 - 0
api/core/app/features/rate_limiting/rate_limit.py

@@ -0,0 +1,120 @@
+import logging
+import time
+import uuid
+from collections.abc import Generator
+from datetime import timedelta
+from typing import Optional, Union
+
+from core.errors.error import AppInvokeQuotaExceededError
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+
+class RateLimit:
+    _MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
+    _ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
+    _UNLIMITED_REQUEST_ID = "unlimited_request_id"
+    _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
+    _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
+    _instance_dict = {}
+
+    def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
+        if client_id not in cls._instance_dict:
+            instance = super().__new__(cls)
+            cls._instance_dict[client_id] = instance
+        return cls._instance_dict[client_id]
+
+    def __init__(self, client_id: str, max_active_requests: int):
+        self.max_active_requests = max_active_requests
+        if hasattr(self, 'initialized'):
+            return
+        self.initialized = True
+        self.client_id = client_id
+        self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
+        self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
+        self.last_recalculate_time = float('-inf')
+        self.flush_cache(use_local_value=True)
+
+    def flush_cache(self, use_local_value=False):
+        self.last_recalculate_time = time.time()
+        # flush max active requests
+        if use_local_value or not redis_client.exists(self.max_active_requests_key):
+            with redis_client.pipeline() as pipe:
+                pipe.set(self.max_active_requests_key, self.max_active_requests)
+                pipe.expire(self.max_active_requests_key, timedelta(days=1))
+                pipe.execute()
+        else:
+            with redis_client.pipeline() as pipe:
+                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
+                redis_client.expire(self.max_active_requests_key, timedelta(days=1))
+
+        # flush max active requests (in-transit request list)
+        if not redis_client.exists(self.active_requests_key):
+            return
+        request_details = redis_client.hgetall(self.active_requests_key)
+        redis_client.expire(self.active_requests_key, timedelta(days=1))
+        timeout_requests = [k for k, v in request_details.items() if
+                            time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
+        if timeout_requests:
+            redis_client.hdel(self.active_requests_key, *timeout_requests)
+
+    def enter(self, request_id: Optional[str] = None) -> str:
+        if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
+            self.flush_cache()
+        if self.max_active_requests <= 0:
+            return RateLimit._UNLIMITED_REQUEST_ID
+        if not request_id:
+            request_id = RateLimit.gen_request_key()
+
+        active_requests_count = redis_client.hlen(self.active_requests_key)
+        if active_requests_count >= self.max_active_requests:
+            raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
+                                              "concurrent requests allowed is {}.".format(self.max_active_requests))
+        redis_client.hset(self.active_requests_key, request_id, str(time.time()))
+        return request_id
+
+    def exit(self, request_id: str):
+        if request_id == RateLimit._UNLIMITED_REQUEST_ID:
+            return
+        redis_client.hdel(self.active_requests_key, request_id)
+
+    @staticmethod
+    def gen_request_key() -> str:
+        return str(uuid.uuid4())
+
+    def generate(self, generator: Union[Generator, callable, dict], request_id: str):
+        if isinstance(generator, dict):
+            return generator
+        else:
+            return RateLimitGenerator(self, generator, request_id)
+
+
+class RateLimitGenerator:
+    def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
+        self.rate_limit = rate_limit
+        if callable(generator):
+            self.generator = generator()
+        else:
+            self.generator = generator
+        self.request_id = request_id
+        self.closed = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.closed:
+            raise StopIteration
+        try:
+            return next(self.generator)
+        except StopIteration:
+            self.close()
+            raise
+
+    def close(self):
+        if not self.closed:
+            self.closed = True
+            self.rate_limit.exit(self.request_id)
+            if self.generator is not None and hasattr(self.generator, 'close'):
+                self.generator.close()

+ 7 - 0
api/core/errors/error.py

@@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
     description = "Quota Exceeded"
 
 
+class AppInvokeQuotaExceededError(Exception):
+    """
+    Custom exception raised when the quota for an app has been exceeded.
+    """
+    description = "App Invoke Quota Exceeded"
+
+
 class ModelCurrentlyNotSupportError(Exception):
     """
     Custom exception raised when the model not support

+ 1 - 0
api/fields/app_fields.py

@@ -72,6 +72,7 @@ tag_fields = {
 app_partial_fields = {
     'id': fields.String,
     'name': fields.String,
+    'max_active_requests': fields.Raw(),
     'description': fields.String(attribute='desc_or_prompt'),
     'mode': fields.String(attribute='mode_compatible_with_agent'),
     'icon': fields.String,

+ 9 - 0
api/libs/external_api.py

@@ -6,6 +6,8 @@ from flask_restful import Api, http_status_message
 from werkzeug.datastructures import Headers
 from werkzeug.exceptions import HTTPException
 
+from core.errors.error import AppInvokeQuotaExceededError
+
 
 class ExternalApi(Api):
 
@@ -43,6 +45,13 @@ class ExternalApi(Api):
                 'message': str(e),
                 'status': status_code
             }
+        elif isinstance(e, AppInvokeQuotaExceededError):
+            status_code = 429
+            default_data = {
+                'code': 'too_many_requests',
+                'message': str(e),
+                'status': status_code
+            }
         else:
             status_code = 500
             default_data = {

+ 33 - 0
api/migrations/versions/408176b91ad3_add_max_active_requests.py

@@ -0,0 +1,33 @@
+"""'add_max_active_requests'
+
+Revision ID: 408176b91ad3
+Revises: 7e6a8693e07a
+Create Date: 2024-07-04 09:25:14.029023
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '408176b91ad3'
+down_revision = '161cadc1af8d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('apps', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('max_active_requests', sa.Integer(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('apps', schema=None) as batch_op:
+        batch_op.drop_column('max_active_requests')
+
+    # ### end Alembic commands ###

+ 1 - 0
api/models/model.py

@@ -74,6 +74,7 @@ class App(db.Model):
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
     tracing = db.Column(db.Text, nullable=True)
+    max_active_requests = db.Column(db.Integer, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 

+ 63 - 46
api/services/app_generate_service.py

@@ -7,6 +7,7 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
 from core.app.apps.completion.app_generator import CompletionAppGenerator
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.features.rate_limiting import RateLimit
 from models.model import Account, App, AppMode, EndUser
 from services.workflow_service import WorkflowService
 
@@ -29,52 +30,68 @@ class AppGenerateService:
         :param streaming: streaming
         :return:
         """
-        if app_model.mode == AppMode.COMPLETION.value:
-            return CompletionAppGenerator().generate(
-                app_model=app_model,
-                user=user,
-                args=args,
-                invoke_from=invoke_from,
-                stream=streaming
-            )
-        elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
-            return AgentChatAppGenerator().generate(
-                app_model=app_model,
-                user=user,
-                args=args,
-                invoke_from=invoke_from,
-                stream=streaming
-            )
-        elif app_model.mode == AppMode.CHAT.value:
-            return ChatAppGenerator().generate(
-                app_model=app_model,
-                user=user,
-                args=args,
-                invoke_from=invoke_from,
-                stream=streaming
-            )
-        elif app_model.mode == AppMode.ADVANCED_CHAT.value:
-            workflow = cls._get_workflow(app_model, invoke_from)
-            return AdvancedChatAppGenerator().generate(
-                app_model=app_model,
-                workflow=workflow,
-                user=user,
-                args=args,
-                invoke_from=invoke_from,
-                stream=streaming
-            )
-        elif app_model.mode == AppMode.WORKFLOW.value:
-            workflow = cls._get_workflow(app_model, invoke_from)
-            return WorkflowAppGenerator().generate(
-                app_model=app_model,
-                workflow=workflow,
-                user=user,
-                args=args,
-                invoke_from=invoke_from,
-                stream=streaming
-            )
-        else:
-            raise ValueError(f'Invalid app mode {app_model.mode}')
+        max_active_request = AppGenerateService._get_max_active_requests(app_model)
+        rate_limit = RateLimit(app_model.id, max_active_request)
+        request_id = RateLimit.gen_request_key()
+        try:
+            request_id = rate_limit.enter(request_id)
+            if app_model.mode == AppMode.COMPLETION.value:
+                return rate_limit.generate(CompletionAppGenerator().generate(
+                    app_model=app_model,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    stream=streaming
+                ), request_id)
+            elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
+                return rate_limit.generate(AgentChatAppGenerator().generate(
+                    app_model=app_model,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    stream=streaming
+                ), request_id)
+            elif app_model.mode == AppMode.CHAT.value:
+                return rate_limit.generate(ChatAppGenerator().generate(
+                    app_model=app_model,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    stream=streaming
+                ), request_id)
+            elif app_model.mode == AppMode.ADVANCED_CHAT.value:
+                workflow = cls._get_workflow(app_model, invoke_from)
+                return rate_limit.generate(AdvancedChatAppGenerator().generate(
+                    app_model=app_model,
+                    workflow=workflow,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    stream=streaming
+                ), request_id)
+            elif app_model.mode == AppMode.WORKFLOW.value:
+                workflow = cls._get_workflow(app_model, invoke_from)
+                return rate_limit.generate(WorkflowAppGenerator().generate(
+                    app_model=app_model,
+                    workflow=workflow,
+                    user=user,
+                    args=args,
+                    invoke_from=invoke_from,
+                    stream=streaming
+                ), request_id)
+            else:
+                raise ValueError(f'Invalid app mode {app_model.mode}')
+        finally:
+            if not streaming:
+                rate_limit.exit(request_id)
+
+    @staticmethod
+    def _get_max_active_requests(app_model: App) -> int:
+        max_active_requests = app_model.max_active_requests
+        if app_model.max_active_requests is None:
+            from flask import current_app
+            max_active_requests = int(current_app.config['APP_MAX_ACTIVE_REQUESTS'])
+        return max_active_requests
 
     @classmethod
     def generate_single_iteration(cls, app_model: App,

+ 5 - 0
api/services/app_service.py

@@ -10,6 +10,7 @@ from flask_sqlalchemy.pagination import Pagination
 
 from constants.model_template import default_app_templates
 from core.agent.entities import AgentToolEntity
+from core.app.features.rate_limiting import RateLimit
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@@ -324,11 +325,15 @@ class AppService:
         """
         app.name = args.get('name')
         app.description = args.get('description', '')
+        app.max_active_requests = args.get('max_active_requests')
         app.icon = args.get('icon')
         app.icon_background = args.get('icon_background')
         app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db.session.commit()
 
+        if app.max_active_requests is not None:
+            rate_limit = RateLimit(app.id, app.max_active_requests)
+            rate_limit.flush_cache(use_local_value=True)
         return app
 
     def update_app_name(self, app: App, name: str) -> App:

+ 2 - 0
docker-legacy/docker-compose.yaml

@@ -39,6 +39,8 @@ services:
       # File Access Time specifies a time interval in seconds for the file to be accessed.
       # The default value is 300 seconds.
       FILES_ACCESS_TIMEOUT: 300
+      # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
+      APP_MAX_ACTIVE_REQUESTS: ${FILES_ACCESS_TIMEOUT:-0}
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.

+ 3 - 0
docker/.env.example

@@ -91,6 +91,9 @@ MIGRATION_ENABLED=true
 # The default value is 300 seconds.
 FILES_ACCESS_TIMEOUT=300
 
+# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
+APP_MAX_ACTIVE_REQUESTS=0
+
 # ------------------------------
 # Container Startup Related Configuration
 # Only effective when starting with docker image or docker-compose.

+ 1 - 0
docker/docker-compose.yaml

@@ -12,6 +12,7 @@ x-shared-env: &shared-api-worker-env
   OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1}
   FILES_URL: ${FILES_URL:-}
   FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
+  APP_MAX_ACTIVE_REQUESTS: ${FILES_ACCESS_TIMEOUT:-0}
   MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
   DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION}
   DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}