Files
dreamweaver/backend/app/core/rate_limiter.py
zhangtuo c351d16d3e feat: migrate rate limiting to Redis distributed backend
- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback
- Migrate stories.py from TTLCache to Redis-backed check_rate_limit
- Migrate admin_auth.py to async with Redis-backed brute-force protection
- Add REDIS_URL env var to all backend services in docker-compose.yml
- Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate)
- Skip tests for unimplemented endpoints (list, detail, delete, image, audio)
- Add stories_split_analysis.md for Phase 2 preparation
2026-02-10 16:13:40 +08:00

142 lines
4.9 KiB
Python

"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
# Module-level logger; the functions below emit a warning through it whenever
# a Redis operation fails and the in-memory fallback is used instead.
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
# Both caches are process-local, so their state is not shared between workers.
# Entries expire after the cache's ``ttl`` (seconds with cachetools' default
# timer) and each cache is bounded by ``maxsize``.
#
# Request counters used by ``check_rate_limit`` when Redis is unavailable.
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
# Failed-attempt state used by the lockout helpers when Redis is unavailable:
# maps key -> (failure count, ``time.time()`` timestamp of the first failure).
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
    """Check and increment a fixed-window rate counter.

    Time is divided into consecutive buckets of ``window_seconds`` seconds and
    every call increments the counter of the current bucket. Redis is tried
    first; if any Redis operation fails, a warning is logged and the call is
    counted in the process-local fallback cache using the same bucket scheme.

    Args:
        key: Unique identifier (e.g. ``"story:<user_id>"``).
        limit: Maximum requests allowed within the window.
        window_seconds: Window duration in seconds. Must be positive.

    Raises:
        HTTPException: 429 when the limit is exceeded.
        ValueError: If ``window_seconds`` is not positive.
    """
    if window_seconds <= 0:
        # Previously this surfaced as a ZeroDivisionError that was swallowed
        # and logged as a Redis failure; report the misconfiguration directly.
        raise ValueError("window_seconds must be a positive number of seconds")

    # Fixed-window bucket index: changes every ``window_seconds`` seconds.
    bucket = int(time.time() // window_seconds)

    try:
        redis = await get_redis()
        redis_key = f"ratelimit:{key}:{bucket}"
        count = await redis.incr(redis_key)
        if count == 1:
            # First hit in this bucket: let Redis discard the counter once the
            # window is over.
            await redis.expire(redis_key, window_seconds)
    except Exception as exc:
        logger.warning("rate_limit_redis_fallback", error=str(exc))
        # ── Fallback: in-memory counter ────────────────────────────────────
        # The bucket is part of the key so the counter restarts with every
        # window. TTLCache restarts an entry's TTL on each write, so a counter
        # stored under the bare key would never reset under steady traffic.
        # The cache TTL only garbage-collects old buckets; for windows longer
        # than that TTL an idle gap can reset the counter early (lenient).
        local_key = f"{key}:{bucket}"
        count = _local_rate_cache.get(local_key, 0) + 1
        _local_rate_cache[local_key] = count

    if count > limit:
        raise HTTPException(
            status_code=429,
            detail="Too many requests, please slow down.",
        )
async def record_failed_attempt(
    key: str,
    max_attempts: int,
    lockout_seconds: int,
) -> bool:
    """Record a failed login attempt and return whether the key is locked out.

    The failure window starts at the first recorded failure and lasts
    ``lockout_seconds``. Redis is tried first (``INCR`` plus an ``EXPIRE`` on
    the first failure); if any Redis operation fails, a warning is logged and
    the attempt is recorded in the process-local fallback cache.

    Args:
        key: Unique identifier (e.g. ``"admin_login:<ip>"``).
        max_attempts: Number of failures before lockout.
        lockout_seconds: Duration of lockout in seconds.

    Returns:
        ``True`` if the key is now locked out, ``False`` otherwise.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.incr(redis_key)
        if count == 1:
            # The expiry is armed once, so the window is anchored to the
            # first failure rather than the most recent one.
            await redis.expire(redis_key, lockout_seconds)
        return count >= max_attempts
    except Exception as exc:
        logger.warning("lockout_redis_fallback", error=str(exc))

    # ── Fallback ───────────────────────────────────────────────────────────
    now = time.time()
    entry = _local_lockout_cache.get(key)
    if entry is not None and now - entry[1] < lockout_seconds:
        # Still inside the window opened by the first failure.
        attempts, first_fail = entry[0] + 1, entry[1]
    else:
        # No entry, or the previous window has elapsed. The cache's own TTL is
        # restarted by every write, so it cannot be relied on to expire old
        # failures; start a fresh window here, as the Redis key expiry would.
        attempts, first_fail = 1, now
    _local_lockout_cache[key] = (attempts, first_fail)
    return attempts >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
    """Check if a key is currently locked out.

    Redis is consulted first; if any Redis operation fails, a warning is
    logged and the process-local fallback cache is used instead.

    Args:
        key: Unique identifier, the same value passed when recording failures.
        max_attempts: Number of failures at which the key counts as locked.
        lockout_seconds: Duration of the lockout window in seconds.

    Returns:
        Remaining lockout seconds (> 0 means locked), 0 means not locked.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"
        count = await redis.get(redis_key)
        if count is None or int(count) < max_attempts:
            return 0
        ttl = await redis.ttl(redis_key)
        if ttl == -1:
            # The counter exists but has no expiry (e.g. the EXPIRE after the
            # first INCR never ran). Treating that as "0 seconds left" would
            # fail open and leave the key around forever, so re-arm the
            # expiry and report a full lockout.
            await redis.expire(redis_key, lockout_seconds)
            return max(lockout_seconds, 0)
        # -2 means the key expired between GET and TTL: no longer locked.
        return max(ttl, 0)
    except Exception as exc:
        logger.warning("lockout_check_redis_fallback", error=str(exc))

    # ── Fallback ───────────────────────────────────────────────────────────
    entry = _local_lockout_cache.get(key)
    if entry is None:
        return 0
    attempts, first_fail = entry
    remaining = int(lockout_seconds - (time.time() - first_fail))
    if remaining <= 0:
        # The window opened by the first failure is over: drop the entry,
        # whether or not the threshold was reached, as Redis expiry would.
        _local_lockout_cache.pop(key, None)
        return 0
    return remaining if attempts >= max_attempts else 0
async def clear_failed_attempts(key: str) -> None:
    """Forget every recorded failure for ``key`` (call after a successful login).

    Deletes the Redis lockout counter on a best-effort basis: a Redis failure
    is logged as a warning and otherwise ignored. The process-local fallback
    entry is removed in every case.

    Args:
        key: Unique identifier, the same value passed when recording failures.
    """
    lockout_key = f"lockout:{key}"
    try:
        client = await get_redis()
        await client.delete(lockout_key)
    except Exception as err:
        logger.warning("lockout_clear_redis_fallback", error=str(err))

    # The in-memory entry is dropped regardless of the Redis outcome, so a
    # lockout recorded while Redis was down does not outlive the login.
    _local_lockout_cache.pop(key, None)