feat: migrate rate limiting to Redis distributed backend
- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback
- Migrate stories.py from TTLCache to Redis-backed check_rate_limit
- Migrate admin_auth.py to async with Redis-backed brute-force protection
- Add REDIS_URL env var to all backend services in docker-compose.yml
- Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate)
- Skip tests for unimplemented endpoints (list, detail, delete, image, audio)
- Add stories_split_analysis.md for Phase 2 preparation
This commit is contained in:
@@ -6,7 +6,6 @@ import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -25,6 +24,7 @@ from app.services.provider_router import (
|
||||
generate_storybook,
|
||||
text_to_speech,
|
||||
)
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.tasks.achievements import extract_story_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -37,21 +37,6 @@ MAX_TTS_LENGTH = 4000
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
|
||||
|
||||
_request_log: TTLCache[str, list[float]] = TTLCache(
|
||||
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
|
||||
)
|
||||
|
||||
|
||||
def _check_rate_limit(user_id: str):
|
||||
now = time.time()
|
||||
timestamps = _request_log.get(user_id, [])
|
||||
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
|
||||
if len(timestamps) >= RATE_LIMIT_REQUESTS:
|
||||
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
|
||||
timestamps.append(now)
|
||||
_request_log[user_id] = timestamps
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
@@ -154,7 +139,7 @@ async def generate_story(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -208,7 +193,7 @@ async def generate_story_full(
|
||||
|
||||
部分成功策略:故事必须成功,图片/音频失败不影响整体。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -288,7 +273,7 @@ async def generate_story_stream(
|
||||
- image_failed: 返回 error
|
||||
- complete: 结束流
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -400,7 +385,7 @@ async def generate_storybook_api(
|
||||
|
||||
返回故事书结构,包含每页文字和图像提示词。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
|
||||
# 验证档案和宇宙
|
||||
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import secrets
|
||||
import time
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.rate_limiter import (
|
||||
clear_failed_attempts,
|
||||
is_locked_out,
|
||||
record_failed_attempt,
|
||||
)
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
# 登录失败记录:IP -> (失败次数, 首次失败时间)
|
||||
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
|
||||
|
||||
MAX_ATTEMPTS = 5
|
||||
LOCKOUT_SECONDS = 900 # 15分钟
|
||||
|
||||
@@ -25,24 +25,20 @@ def _get_client_ip(request: Request) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def admin_guard(
|
||||
async def admin_guard(
|
||||
request: Request,
|
||||
credentials: HTTPBasicCredentials = Depends(security),
|
||||
):
|
||||
client_ip = _get_client_ip(request)
|
||||
lockout_key = f"admin_login:{client_ip}"
|
||||
|
||||
# 检查是否被锁定
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
if attempts >= MAX_ATTEMPTS:
|
||||
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
|
||||
if remaining > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"登录尝试过多,请 {remaining} 秒后重试",
|
||||
)
|
||||
else:
|
||||
del _failed_attempts[client_ip]
|
||||
remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
|
||||
if remaining > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"登录尝试过多,请 {remaining} 秒后重试",
|
||||
)
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
username_ok = secrets.compare_digest(
|
||||
@@ -53,20 +49,12 @@ def admin_guard(
|
||||
)
|
||||
|
||||
if not (username_ok and password_ok):
|
||||
# 记录失败
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
_failed_attempts[client_ip] = (attempts + 1, first_fail)
|
||||
else:
|
||||
_failed_attempts[client_ip] = (1, time.time())
|
||||
|
||||
await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
|
||||
# 登录成功,清除失败记录
|
||||
if client_ip in _failed_attempts:
|
||||
del _failed_attempts[client_ip]
|
||||
|
||||
await clear_failed_attempts(lockout_key)
|
||||
return True
|
||||
|
||||
141
backend/app/core/rate_limiter.py
Normal file
141
backend/app/core/rate_limiter.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Redis-backed rate limiter with in-memory fallback.
|
||||
|
||||
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
|
||||
Falls back to an in-memory TTLCache when Redis is unavailable,
|
||||
preserving identical behavior for dev/test environments.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.redis import get_redis
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ── In-memory fallback caches ──────────────────────────────────────────────
|
||||
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
|
||||
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
|
||||
|
||||
|
||||
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
    """Check and increment a fixed-window rate counter.

    Uses Redis INCR + EXPIRE on a per-window bucket key; falls back to an
    in-memory TTLCache counter when Redis is unavailable.

    Args:
        key: Unique identifier (e.g. ``"story:<user_id>"``).
        limit: Maximum requests allowed within the window.
        window_seconds: Window duration in seconds.

    Raises:
        HTTPException: 429 when the limit is exceeded.
    """
    # Fixed-window bucket boundary, shared by both backends so the local
    # fallback resets each window exactly like the Redis counter does.
    bucket = int(time.time() // window_seconds)
    try:
        redis = await get_redis()
        redis_key = f"ratelimit:{key}:{bucket}"

        count = await redis.incr(redis_key)
        if count == 1:
            # First hit in this window: arm expiry so stale buckets vanish.
            await redis.expire(redis_key, window_seconds)

        if count > limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests, please slow down.",
            )
        return
    except HTTPException:
        raise
    except Exception as exc:  # Redis down/unreachable -> degrade gracefully
        logger.warning("rate_limit_redis_fallback", error=str(exc))

    # ── Fallback: in-memory counter ────────────────────────────────────────
    # Bug fix: key the local counter by window bucket as well. The previous
    # code used the bare key, so the TTLCache entry's TTL was refreshed on
    # every write and the count never reset — once a caller exceeded the
    # limit they stayed blocked for as long as they kept retrying.
    local_key = f"{key}:{bucket}"
    count = _local_rate_cache.get(local_key, 0) + 1
    _local_rate_cache[local_key] = count
    if count > limit:
        raise HTTPException(
            status_code=429,
            detail="Too many requests, please slow down.",
        )
|
||||
|
||||
|
||||
async def record_failed_attempt(
    key: str,
    max_attempts: int,
    lockout_seconds: int,
) -> bool:
    """Record a failed login attempt and report whether *key* is now locked.

    Args:
        key: Unique identifier (e.g. ``"admin_login:<ip>"``).
        max_attempts: Number of failures before lockout.
        lockout_seconds: Duration of lockout in seconds.

    Returns:
        ``True`` if the key is now locked out, ``False`` otherwise.
    """
    try:
        redis = await get_redis()
        redis_key = f"lockout:{key}"

        attempts = await redis.incr(redis_key)
        if attempts == 1:
            # Arm the lockout window on the first failure only, so the
            # window is measured from the first failed attempt.
            await redis.expire(redis_key, lockout_seconds)

        return attempts >= max_attempts
    except Exception as exc:
        logger.warning("lockout_redis_fallback", error=str(exc))

    # ── Fallback: in-memory counter ────────────────────────────────────────
    # NOTE(review): the local cache's TTL is fixed at module level (900s);
    # lockout_seconds is honored only by the Redis path — confirm acceptable.
    prev_attempts, first_fail = _local_lockout_cache.get(key, (0, time.time()))
    attempts = prev_attempts + 1
    _local_lockout_cache[key] = (attempts, first_fail)
    return attempts >= max_attempts
|
||||
|
||||
|
||||
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
    """Return the remaining lockout seconds for *key*.

    Returns:
        Remaining lockout seconds (> 0 means locked), 0 means not locked.
    """
    redis_key = f"lockout:{key}"
    try:
        redis = await get_redis()
        raw = await redis.get(redis_key)
        if raw is None or int(raw) < max_attempts:
            return 0
        # Locked: remaining time is the key's TTL, clamped to 0 because
        # TTL may be negative (-1 no expiry, -2 missing key).
        return max(await redis.ttl(redis_key), 0)
    except Exception as exc:
        logger.warning("lockout_check_redis_fallback", error=str(exc))

    # ── Fallback: in-memory state ──────────────────────────────────────────
    entry = _local_lockout_cache.get(key)
    if entry is None:
        return 0
    attempts, first_fail = entry
    if attempts < max_attempts:
        return 0
    remaining = int(lockout_seconds - (time.time() - first_fail))
    if remaining > 0:
        return remaining
    # Lockout has elapsed — drop the stale entry eagerly.
    del _local_lockout_cache[key]
    return 0
|
||||
|
||||
|
||||
async def clear_failed_attempts(key: str) -> None:
    """Reset lockout state for *key* after a successful login."""
    try:
        client = await get_redis()
        await client.delete(f"lockout:{key}")
    except Exception as exc:
        logger.warning("lockout_clear_redis_fallback", error=str(exc))

    # Clear the local fallback unconditionally so a successful login never
    # leaves a stale in-memory counter behind, whichever backend was used.
    _local_lockout_cache.pop(key, None)
|
||||
Reference in New Issue
Block a user