|
@@ -0,0 +1,120 @@
|
|
|
+import logging
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+from collections.abc import Generator
|
|
|
+from datetime import timedelta
|
|
|
+from typing import Optional, Union
|
|
|
+
|
|
|
+from core.errors.error import AppInvokeQuotaExceededError
|
|
|
+from extensions.ext_redis import redis_client
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class RateLimit:
|
|
|
+ _MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
|
|
+ _ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
|
|
+ _UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
|
|
+ _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
|
|
+ _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
|
|
+ _instance_dict = {}
|
|
|
+
|
|
|
+ def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
|
|
+ if client_id not in cls._instance_dict:
|
|
|
+ instance = super().__new__(cls)
|
|
|
+ cls._instance_dict[client_id] = instance
|
|
|
+ return cls._instance_dict[client_id]
|
|
|
+
|
|
|
+ def __init__(self, client_id: str, max_active_requests: int):
|
|
|
+ self.max_active_requests = max_active_requests
|
|
|
+ if hasattr(self, 'initialized'):
|
|
|
+ return
|
|
|
+ self.initialized = True
|
|
|
+ self.client_id = client_id
|
|
|
+ self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
|
|
+ self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
|
|
+ self.last_recalculate_time = float('-inf')
|
|
|
+ self.flush_cache(use_local_value=True)
|
|
|
+
|
|
|
+ def flush_cache(self, use_local_value=False):
|
|
|
+ self.last_recalculate_time = time.time()
|
|
|
+ # flush max active requests
|
|
|
+ if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
|
|
+ with redis_client.pipeline() as pipe:
|
|
|
+ pipe.set(self.max_active_requests_key, self.max_active_requests)
|
|
|
+ pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
|
|
+ pipe.execute()
|
|
|
+ else:
|
|
|
+ with redis_client.pipeline() as pipe:
|
|
|
+ self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
|
|
+ redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
|
|
+
|
|
|
+ # flush max active requests (in-transit request list)
|
|
|
+ if not redis_client.exists(self.active_requests_key):
|
|
|
+ return
|
|
|
+ request_details = redis_client.hgetall(self.active_requests_key)
|
|
|
+ redis_client.expire(self.active_requests_key, timedelta(days=1))
|
|
|
+ timeout_requests = [k for k, v in request_details.items() if
|
|
|
+ time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
|
|
+ if timeout_requests:
|
|
|
+ redis_client.hdel(self.active_requests_key, *timeout_requests)
|
|
|
+
|
|
|
+ def enter(self, request_id: Optional[str] = None) -> str:
|
|
|
+ if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
|
|
+ self.flush_cache()
|
|
|
+ if self.max_active_requests <= 0:
|
|
|
+ return RateLimit._UNLIMITED_REQUEST_ID
|
|
|
+ if not request_id:
|
|
|
+ request_id = RateLimit.gen_request_key()
|
|
|
+
|
|
|
+ active_requests_count = redis_client.hlen(self.active_requests_key)
|
|
|
+ if active_requests_count >= self.max_active_requests:
|
|
|
+ raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
|
|
+ "concurrent requests allowed is {}.".format(self.max_active_requests))
|
|
|
+ redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
|
|
+ return request_id
|
|
|
+
|
|
|
+ def exit(self, request_id: str):
|
|
|
+ if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
|
|
+ return
|
|
|
+ redis_client.hdel(self.active_requests_key, request_id)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def gen_request_key() -> str:
|
|
|
+ return str(uuid.uuid4())
|
|
|
+
|
|
|
+ def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
|
|
+ if isinstance(generator, dict):
|
|
|
+ return generator
|
|
|
+ else:
|
|
|
+ return RateLimitGenerator(self, generator, request_id)
|
|
|
+
|
|
|
+
|
|
|
+class RateLimitGenerator:
|
|
|
+ def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
|
|
+ self.rate_limit = rate_limit
|
|
|
+ if callable(generator):
|
|
|
+ self.generator = generator()
|
|
|
+ else:
|
|
|
+ self.generator = generator
|
|
|
+ self.request_id = request_id
|
|
|
+ self.closed = False
|
|
|
+
|
|
|
+ def __iter__(self):
|
|
|
+ return self
|
|
|
+
|
|
|
+ def __next__(self):
|
|
|
+ if self.closed:
|
|
|
+ raise StopIteration
|
|
|
+ try:
|
|
|
+ return next(self.generator)
|
|
|
+ except StopIteration:
|
|
|
+ self.close()
|
|
|
+ raise
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ if not self.closed:
|
|
|
+ self.closed = True
|
|
|
+ self.rate_limit.exit(self.request_id)
|
|
|
+ if self.generator is not None and hasattr(self.generator, 'close'):
|
|
|
+ self.generator.close()
|