feat: migrate rate limiting to Redis distributed backend

- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback
- Migrate stories.py from TTLCache to Redis-backed check_rate_limit
- Migrate admin_auth.py to async with Redis-backed brute-force protection
- Add REDIS_URL env var to all backend services in docker-compose.yml
- Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate)
- Skip tests for unimplemented endpoints (list, detail, delete, image, audio)
- Add stories_split_analysis.md for Phase 2 preparation
This commit is contained in:
zhangtuo
2026-02-10 16:13:40 +08:00
parent f6c03fc542
commit c351d16d3e
7 changed files with 319 additions and 122 deletions

View File

@@ -6,7 +6,6 @@ import time
import uuid
from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
@@ -25,6 +24,7 @@ from app.services.provider_router import (
generate_storybook,
text_to_speech,
)
from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
@@ -37,21 +37,6 @@ MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel):
@@ -154,7 +139,7 @@ async def generate_story(
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -208,7 +193,7 @@ async def generate_story_full(
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -288,7 +273,7 @@ async def generate_story_stream(
- image_failed: 返回 error
- complete: 结束流
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -400,7 +385,7 @@ async def generate_storybook_api(
返回故事书结构,包含每页文字和图像提示词。
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数

View File

@@ -1,17 +1,17 @@
import secrets
import time
from cachetools import TTLCache
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
from app.core.rate_limiter import (
clear_failed_attempts,
is_locked_out,
record_failed_attempt,
)
security = HTTPBasic()
# 登录失败记录IP -> (失败次数, 首次失败时间)
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900 # 15分钟
@@ -25,24 +25,20 @@ def _get_client_ip(request: Request) -> str:
return "unknown"
def admin_guard(
async def admin_guard(
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
):
client_ip = _get_client_ip(request)
lockout_key = f"admin_login:{client_ip}"
# 检查是否被锁定
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
if attempts >= MAX_ATTEMPTS:
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试",
)
else:
del _failed_attempts[client_ip]
remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试",
)
# 使用 secrets.compare_digest 防止时序攻击
username_ok = secrets.compare_digest(
@@ -53,20 +49,12 @@ def admin_guard(
)
if not (username_ok and password_ok):
# 记录失败
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
_failed_attempts[client_ip] = (attempts + 1, first_fail)
else:
_failed_attempts[client_ip] = (1, time.time())
await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
# 登录成功,清除失败记录
if client_ip in _failed_attempts:
del _failed_attempts[client_ip]
await clear_failed_attempts(lockout_key)
return True

View File

@@ -0,0 +1,141 @@
"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
"""Check and increment a sliding-window rate counter.
Args:
key: Unique identifier (e.g. ``"story:<user_id>"``).
limit: Maximum requests allowed within the window.
window_seconds: Window duration in seconds.
Raises:
HTTPException: 429 when the limit is exceeded.
"""
try:
redis = await get_redis()
# Fixed-window bucket: key + minute boundary
bucket = int(time.time() // window_seconds)
redis_key = f"ratelimit:{key}:{bucket}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, window_seconds)
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
return
except HTTPException:
raise
except Exception as exc:
logger.warning("rate_limit_redis_fallback", error=str(exc))
# ── Fallback: in-memory counter ────────────────────────────────────────
count = _local_rate_cache.get(key, 0) + 1
_local_rate_cache[key] = count
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
async def record_failed_attempt(
key: str,
max_attempts: int,
lockout_seconds: int,
) -> bool:
"""Record a failed login attempt and return whether the key is locked out.
Args:
key: Unique identifier (e.g. ``"admin_login:<ip>"``).
max_attempts: Number of failures before lockout.
lockout_seconds: Duration of lockout in seconds.
Returns:
``True`` if the key is now locked out, ``False`` otherwise.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, lockout_seconds)
return count >= max_attempts
except Exception as exc:
logger.warning("lockout_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
_local_lockout_cache[key] = (attempts + 1, first_fail)
return (attempts + 1) >= max_attempts
else:
_local_lockout_cache[key] = (1, time.time())
return 1 >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
"""Check if a key is currently locked out.
Returns:
Remaining lockout seconds (> 0 means locked), 0 means not locked.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.get(redis_key)
if count is not None and int(count) >= max_attempts:
ttl = await redis.ttl(redis_key)
return max(ttl, 0)
return 0
except Exception as exc:
logger.warning("lockout_check_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
if attempts >= max_attempts:
remaining = int(lockout_seconds - (time.time() - first_fail))
if remaining > 0:
return remaining
else:
del _local_lockout_cache[key]
return 0
async def clear_failed_attempts(key: str) -> None:
"""Clear lockout state on successful login."""
try:
redis = await get_redis()
await redis.delete(f"lockout:{key}")
except Exception as exc:
logger.warning("lockout_clear_redis_fallback", error=str(exc))
# Always clear local cache too
_local_lockout_cache.pop(key, None)