"""Redis-backed rate limiter with in-memory fallback.

Uses a fixed-window counter pattern via Redis INCR + EXPIRE. Falls back to
in-memory TTLCaches when Redis is unavailable, approximating the same
fixed-window semantics for dev/test environments. The fallback state is
per-process and is not shared between workers.
"""

import time

from cachetools import TTLCache
from fastapi import HTTPException

from app.core.logging import get_logger
from app.core.redis import get_redis

logger = get_logger(__name__)

# ── In-memory fallback caches ──────────────────────────────────────────────
# NOTE: these cache TTLs are fixed (120 s / 900 s) and do not follow the
# per-call window/lockout durations. cachetools restarts an item's TTL each
# time it is re-assigned, so an entry is dropped once it has gone that long
# without a write. Windows or lockouts longer than these TTLs can therefore be
# forgotten early for keys that go idle.
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)


def _raise_too_many_requests() -> None:
    """Raise the 429 response shared by the Redis and in-memory paths."""
    raise HTTPException(
        status_code=429,
        detail="Too many requests, please slow down.",
    )


async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
    """Check and increment a fixed-window rate counter.

    Requests are counted per bucket ``floor(now / window_seconds)``; the count
    starts over when a new bucket begins (it does not slide with each request).
    Both the Redis path and the in-memory fallback use the same buckets.

    Args:
        key: Unique identifier (e.g. ``"story:"``).
        limit: Maximum requests allowed within the window.
        window_seconds: Window duration in seconds; must be positive.

    Raises:
        ValueError: If ``window_seconds`` is not positive.
        HTTPException: 429 when the limit is exceeded.
    """
    if window_seconds <= 0:
        # Previously this surfaced as a ZeroDivisionError / bogus EXPIRE that
        # was swallowed and logged as a Redis outage.
        raise ValueError("window_seconds must be positive")

    # Fixed-window bucket index, computed once so both backends agree on it.
    bucket = int(time.time() // window_seconds)

    try:
        redis = await get_redis()
        redis_key = f"ratelimit:{key}:{bucket}"
        count = await redis.incr(redis_key)
        if count == 1:
            # First hit in this bucket: let Redis reclaim the key once the
            # window has passed.
            await redis.expire(redis_key, window_seconds)
    except Exception as exc:
        logger.warning("rate_limit_redis_fallback", error=str(exc))
        # ── Fallback: in-memory counter, bucketed like the Redis key ───────
        # Keying on the bucket makes the count reset at each window boundary;
        # keying on ``key`` alone would keep counting for as long as requests
        # kept refreshing the cache entry's TTL.
        local_key = f"{key}:{bucket}"
        count = _local_rate_cache.get(local_key, 0) + 1
        _local_rate_cache[local_key] = count

    if count > limit:
        _raise_too_many_requests()


async def record_failed_attempt(
    key: str,
    max_attempts: int,
    lockout_seconds: int,
) -> bool:
    """Record a failed login attempt and return whether the key is locked out.

    Failures are counted in a window that opens at the first failure and lasts
    ``lockout_seconds``; once it has elapsed, counting starts over.

    Args:
        key: Unique identifier (e.g. ``"admin_login:"``).
        max_attempts: Number of failures before lockout.
        lockout_seconds: Duration of lockout in seconds.

    Returns:
        ``True`` if the key is now locked out, ``False`` otherwise.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.incr(redis_key)
        # INCR and EXPIRE are separate calls. Set the TTL on the first failure,
        # and also repair a key left without one (TTL == -1, e.g. a previous
        # EXPIRE failed after INCR succeeded). Otherwise the counter would never
        # reset, and is_locked_out would see no remaining lockout time.
        if count == 1 or await redis.ttl(redis_key) == -1:
            await redis.expire(redis_key, lockout_seconds)
        return count >= max_attempts
    except Exception as exc:
        logger.warning("lockout_redis_fallback", error=str(exc))

    # ── Fallback: (attempts, first_failure_timestamp) per key ──────────────
    now = time.time()
    entry = _local_lockout_cache.get(key)
    if entry is None or now - entry[1] >= lockout_seconds:
        # No prior failures, or the previous window has elapsed. This mirrors
        # the Redis key expiring lockout_seconds after the first failure.
        attempts, first_fail = 1, now
    else:
        attempts, first_fail = entry[0] + 1, entry[1]
    _local_lockout_cache[key] = (attempts, first_fail)
    return attempts >= max_attempts


async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
    """Check if a key is currently locked out.

    Args:
        key: The identifier passed to ``record_failed_attempt``.
        max_attempts: Failure count at which the key counts as locked.
        lockout_seconds: Lockout duration; the in-memory fallback uses it to
            compute the remaining time from the first failure.

    Returns:
        Remaining lockout seconds (> 0 means locked), 0 means not locked.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.get(redis_key)
        if count is not None and int(count) >= max_attempts:
            ttl = await redis.ttl(redis_key)
            # TTL is negative when the key has vanished (-2) or has no expiry
            # (-1); report "not locked" rather than a negative duration.
            return max(ttl, 0)
        return 0
    except Exception as exc:
        logger.warning("lockout_check_redis_fallback", error=str(exc))

    # ── Fallback ───────────────────────────────────────────────────────────
    entry = _local_lockout_cache.get(key)
    if entry is None:
        return 0
    attempts, first_fail = entry
    if attempts < max_attempts:
        return 0
    remaining = int(lockout_seconds - (time.time() - first_fail))
    if remaining > 0:
        return remaining
    # Lockout has elapsed: drop the stale entry so counting starts fresh.
    _local_lockout_cache.pop(key, None)
    return 0


async def clear_failed_attempts(key: str) -> None:
    """Clear lockout state on successful login."""
    try:
        redis = await get_redis()
        await redis.delete(f"lockout:{key}")
    except Exception as exc:
        logger.warning("lockout_clear_redis_fallback", error=str(exc))
    # Always clear the local cache too: failures recorded during an earlier
    # Redis outage live only in this process's fallback cache.
    _local_lockout_cache.pop(key, None)