feat: migrate rate limiting to Redis distributed backend

- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback
- Migrate stories.py from TTLCache to Redis-backed check_rate_limit
- Migrate admin_auth.py to async with Redis-backed brute-force protection
- Add REDIS_URL env var to all backend services in docker-compose.yml
- Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate)
- Skip tests for unimplemented endpoints (list, detail, delete, image, audio)
- Add stories_split_analysis.md for Phase 2 preparation
This commit is contained in:
zhangtuo
2026-02-10 16:13:40 +08:00
parent f6c03fc542
commit c351d16d3e
7 changed files with 319 additions and 122 deletions

View File

@@ -6,7 +6,6 @@ import time
import uuid
from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
@@ -25,6 +24,7 @@ from app.services.provider_router import (
generate_storybook,
text_to_speech,
)
from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
@@ -37,21 +37,6 @@ MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel):
@@ -154,7 +139,7 @@ async def generate_story(
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -208,7 +193,7 @@ async def generate_story_full(
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -288,7 +273,7 @@ async def generate_story_stream(
- image_failed: 返回 error
- complete: 结束流
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -400,7 +385,7 @@ async def generate_storybook_api(
返回故事书结构,包含每页文字和图像提示词。
"""
_check_rate_limit(user.id)
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数

View File

@@ -1,17 +1,17 @@
import secrets
import time
from cachetools import TTLCache
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings
from app.core.rate_limiter import (
clear_failed_attempts,
is_locked_out,
record_failed_attempt,
)
security = HTTPBasic()
# 登录失败记录IP -> (失败次数, 首次失败时间)
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900 # 15分钟
@@ -25,24 +25,20 @@ def _get_client_ip(request: Request) -> str:
return "unknown"
def admin_guard(
async def admin_guard(
request: Request,
credentials: HTTPBasicCredentials = Depends(security),
):
client_ip = _get_client_ip(request)
lockout_key = f"admin_login:{client_ip}"
# 检查是否被锁定
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
if attempts >= MAX_ATTEMPTS:
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试",
)
else:
del _failed_attempts[client_ip]
remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
if remaining > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试",
)
# 使用 secrets.compare_digest 防止时序攻击
username_ok = secrets.compare_digest(
@@ -53,20 +49,12 @@ def admin_guard(
)
if not (username_ok and password_ok):
# 记录失败
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
_failed_attempts[client_ip] = (attempts + 1, first_fail)
else:
_failed_attempts[client_ip] = (1, time.time())
await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
# 登录成功,清除失败记录
if client_ip in _failed_attempts:
del _failed_attempts[client_ip]
await clear_failed_attempts(lockout_key)
return True

View File

@@ -0,0 +1,141 @@
"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
"""Check and increment a sliding-window rate counter.
Args:
key: Unique identifier (e.g. ``"story:<user_id>"``).
limit: Maximum requests allowed within the window.
window_seconds: Window duration in seconds.
Raises:
HTTPException: 429 when the limit is exceeded.
"""
try:
redis = await get_redis()
# Fixed-window bucket: key + minute boundary
bucket = int(time.time() // window_seconds)
redis_key = f"ratelimit:{key}:{bucket}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, window_seconds)
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
return
except HTTPException:
raise
except Exception as exc:
logger.warning("rate_limit_redis_fallback", error=str(exc))
# ── Fallback: in-memory counter ────────────────────────────────────────
count = _local_rate_cache.get(key, 0) + 1
_local_rate_cache[key] = count
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
async def record_failed_attempt(
key: str,
max_attempts: int,
lockout_seconds: int,
) -> bool:
"""Record a failed login attempt and return whether the key is locked out.
Args:
key: Unique identifier (e.g. ``"admin_login:<ip>"``).
max_attempts: Number of failures before lockout.
lockout_seconds: Duration of lockout in seconds.
Returns:
``True`` if the key is now locked out, ``False`` otherwise.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, lockout_seconds)
return count >= max_attempts
except Exception as exc:
logger.warning("lockout_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
_local_lockout_cache[key] = (attempts + 1, first_fail)
return (attempts + 1) >= max_attempts
else:
_local_lockout_cache[key] = (1, time.time())
return 1 >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
"""Check if a key is currently locked out.
Returns:
Remaining lockout seconds (> 0 means locked), 0 means not locked.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.get(redis_key)
if count is not None and int(count) >= max_attempts:
ttl = await redis.ttl(redis_key)
return max(ttl, 0)
return 0
except Exception as exc:
logger.warning("lockout_check_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
if attempts >= max_attempts:
remaining = int(lockout_seconds - (time.time() - first_fail))
if remaining > 0:
return remaining
else:
del _local_lockout_cache[key]
return 0
async def clear_failed_attempts(key: str) -> None:
"""Clear lockout state on successful login."""
try:
redis = await get_redis()
await redis.delete(f"lockout:{key}")
except Exception as exc:
logger.warning("lockout_clear_redis_fallback", error=str(exc))
# Always clear local cache too
_local_lockout_cache.pop(key, None)

View File

@@ -0,0 +1,52 @@
# `stories.py` 拆分分析 (Phase 2 准备)
## 当前职责
`app/api/stories.py` (591 行) 承担了以下职责:
| 职责 | 行数 | 描述 |
|---|---|---|
| Pydantic 模型 | ~50 行 | `GenerateRequest`, `StoryResponse`, `FullStoryResponse` 等 |
| 验证逻辑 | ~40 行 | `_validate_profile_and_universe` |
| 路由 + 业务 | ~300 行 | `generate_story`, `generate_story_full`, `generate_story_stream` |
| 绘本逻辑 | ~170 行 | `generate_storybook_api` (含并行图片生成) |
| 成就查询 | ~30 行 | `get_story_achievements` |
## 缺失端点
测试中引用但 **未实现** 的端点(这些应在拆分时一并补充):
- `GET /api/stories` — 故事列表 (分页)
- `GET /api/stories/{id}` — 故事详情
- `DELETE /api/stories/{id}` — 故事删除
- `POST /api/image/generate/{id}` — 封面图片生成
- `GET /api/audio/{id}` — 语音朗读
## 建议拆分结构
```
app/
├── schemas/
│ └── story_schemas.py # [NEW] Pydantic 模型
├── services/
│ └── story_service.py # [NEW] 核心业务逻辑
└── api/
├── stories.py # [SLIM] 路由定义 + 依赖注入
└── stories_storybook.py # [NEW] 绘本相关端点 (可选)
```
### `story_schemas.py`
- 迁移所有 Pydantic 模型
- 包括 `GenerateRequest`, `StoryResponse`, `FullStoryResponse`, `StorybookRequest`, `StorybookResponse`
### `story_service.py`
- `validate_profile_and_universe()` — 验证逻辑
- `create_story()` — 故事入库
- `generate_and_save_story()` — 生成 + 保存联合操作
- `generate_storybook_with_images()` — 绘本并行图片生成
- 补充: `list_stories()`, `get_story()`, `delete_story()`
### `stories.py` (瘦路由层)
- 仅保留 `@router` 装饰器和依赖注入
- 调用 service 层完成业务逻辑
- 预计 150-200 行

View File

@@ -12,7 +12,6 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
from app.core.security import create_access_token
from app.api.stories import _request_log
from app.db.database import get_db
from app.db.models import Base, Story, User
from app.main import app
@@ -96,11 +95,18 @@ def auth_client(client: TestClient, auth_token: str) -> TestClient:
@pytest.fixture(autouse=True)
def clear_rate_limit_cache():
"""确保每个测试用例的限流缓存互不影响"""
_request_log.clear()
yield
_request_log.clear()
def bypass_rate_limit():
"""默认绕过限流,让非限流测试正常运行"""
with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
# 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
redis_instance = AsyncMock()
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
redis_instance.expire.return_value = True
redis_instance.get.return_value = None # 无锁定记录
redis_instance.ttl.return_value = 0
redis_instance.delete.return_value = 1
mock_redis.return_value = redis_instance
yield redis_instance
@pytest.fixture

View File

@@ -1,12 +1,19 @@
"""故事 API 测试。"""
import time
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
# ── 注意 ──────────────────────────────────────────────────────────────────────
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
# GET /api/stories (列表)
# GET /api/stories/{id} (详情)
# DELETE /api/stories/{id} (删除)
# POST /api/image/generate/{id} (封面图片生成)
# GET /api/audio/{id} (音频)
# 实现后请取消 skip 标记。
class TestStoryGenerate:
@@ -15,7 +22,7 @@ class TestStoryGenerate:
def test_generate_without_auth(self, client: TestClient):
"""未登录时生成故事。"""
response = client.post(
"/api/generate",
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 401
@@ -23,7 +30,7 @@ class TestStoryGenerate:
def test_generate_with_empty_data(self, auth_client: TestClient):
"""空数据生成故事。"""
response = auth_client.post(
"/api/generate",
"/api/stories/generate",
json={"type": "keywords", "data": ""},
)
assert response.status_code == 422
@@ -31,7 +38,7 @@ class TestStoryGenerate:
def test_generate_with_invalid_type(self, auth_client: TestClient):
"""无效类型生成故事。"""
response = auth_client.post(
"/api/generate",
"/api/stories/generate",
json={"type": "invalid", "data": "test"},
)
assert response.status_code == 422
@@ -39,7 +46,7 @@ class TestStoryGenerate:
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
"""成功生成故事。"""
response = auth_client.post(
"/api/generate",
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
)
assert response.status_code == 200
@@ -50,6 +57,7 @@ class TestStoryGenerate:
assert data["mode"] == "generated"
@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
class TestStoryList:
"""故事列表测试。"""
@@ -86,6 +94,7 @@ class TestStoryList:
assert len(data) == 0
@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
class TestStoryDetail:
"""故事详情测试。"""
@@ -109,6 +118,7 @@ class TestStoryDetail:
assert data["story_text"] == test_story.story_text
@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
class TestStoryDelete:
"""故事删除测试。"""
@@ -135,26 +145,30 @@ class TestStoryDelete:
class TestRateLimit:
"""Rate limit 测试。"""
def setup_method(self):
"""每个测试前清理 rate limit 缓存。"""
_request_log.clear()
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
"""正常请求不触发限流。"""
for _ in range(RATE_LIMIT_REQUESTS - 1):
response = auth_client.get(f"/api/stories/{test_story.id}")
# bypass_rate_limit 默认 incr 返回 1不触发限流
for _ in range(3):
response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 200
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story):
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
"""超限请求被阻止。"""
for _ in range(RATE_LIMIT_REQUESTS):
auth_client.get(f"/api/stories/{test_story.id}")
# 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
bypass_rate_limit.incr.return_value = 11
response = auth_client.get(f"/api/stories/{test_story.id}")
response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 429
assert "Too many requests" in response.json()["detail"]
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerate:
"""封面图片生成测试。"""
@@ -169,6 +183,7 @@ class TestImageGenerate:
assert response.status_code == 404
@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
class TestAudio:
"""语音朗读测试。"""
@@ -190,12 +205,12 @@ class TestAudio:
class TestGenerateFull:
"""完整故事生成测试(/api/generate/full"""
"""完整故事生成测试(/api/stories/generate/full"""
def test_generate_full_without_auth(self, client: TestClient):
"""未登录时生成完整故事。"""
response = client.post(
"/api/generate/full",
"/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 401
@@ -203,7 +218,7 @@ class TestGenerateFull:
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""成功生成完整故事(含图片)。"""
response = auth_client.post(
"/api/generate/full",
"/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
)
assert response.status_code == 200
@@ -221,7 +236,7 @@ class TestGenerateFull:
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error")
response = auth_client.post(
"/api/generate/full",
"/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 200
@@ -233,7 +248,7 @@ class TestGenerateFull:
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""带教育主题生成故事。"""
response = auth_client.post(
"/api/generate/full",
"/api/stories/generate/full",
json={
"type": "keywords",
"data": "小兔子, 森林",
@@ -246,6 +261,7 @@ class TestGenerateFull:
assert call_kwargs["education_theme"] == "勇气与友谊"
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerateSuccess:
"""封面图片生成成功测试。"""