feat: migrate rate limiting to Redis distributed backend
- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback - Migrate stories.py from TTLCache to Redis-backed check_rate_limit - Migrate admin_auth.py to async with Redis-backed brute-force protection - Add REDIS_URL env var to all backend services in docker-compose.yml - Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate) - Skip tests for unimplemented endpoints (list, detail, delete, image, audio) - Add stories_split_analysis.md for Phase 2 preparation
This commit is contained in:
@@ -6,7 +6,6 @@ import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Literal
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -25,6 +24,7 @@ from app.services.provider_router import (
|
||||
generate_storybook,
|
||||
text_to_speech,
|
||||
)
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.tasks.achievements import extract_story_achievements
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -37,21 +37,6 @@ MAX_TTS_LENGTH = 4000
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
|
||||
|
||||
_request_log: TTLCache[str, list[float]] = TTLCache(
|
||||
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
|
||||
)
|
||||
|
||||
|
||||
def _check_rate_limit(user_id: str):
|
||||
now = time.time()
|
||||
timestamps = _request_log.get(user_id, [])
|
||||
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
|
||||
if len(timestamps) >= RATE_LIMIT_REQUESTS:
|
||||
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
|
||||
timestamps.append(now)
|
||||
_request_log[user_id] = timestamps
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
@@ -154,7 +139,7 @@ async def generate_story(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -208,7 +193,7 @@ async def generate_story_full(
|
||||
|
||||
部分成功策略:故事必须成功,图片/音频失败不影响整体。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -288,7 +273,7 @@ async def generate_story_stream(
|
||||
- image_failed: 返回 error
|
||||
- complete: 结束流
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -400,7 +385,7 @@ async def generate_storybook_api(
|
||||
|
||||
返回故事书结构,包含每页文字和图像提示词。
|
||||
"""
|
||||
_check_rate_limit(user.id)
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
|
||||
# 验证档案和宇宙
|
||||
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import secrets
|
||||
import time
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.rate_limiter import (
|
||||
clear_failed_attempts,
|
||||
is_locked_out,
|
||||
record_failed_attempt,
|
||||
)
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
# 登录失败记录:IP -> (失败次数, 首次失败时间)
|
||||
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
|
||||
|
||||
MAX_ATTEMPTS = 5
|
||||
LOCKOUT_SECONDS = 900 # 15分钟
|
||||
|
||||
@@ -25,24 +25,20 @@ def _get_client_ip(request: Request) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def admin_guard(
|
||||
async def admin_guard(
|
||||
request: Request,
|
||||
credentials: HTTPBasicCredentials = Depends(security),
|
||||
):
|
||||
client_ip = _get_client_ip(request)
|
||||
lockout_key = f"admin_login:{client_ip}"
|
||||
|
||||
# 检查是否被锁定
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
if attempts >= MAX_ATTEMPTS:
|
||||
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
|
||||
remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
|
||||
if remaining > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"登录尝试过多,请 {remaining} 秒后重试",
|
||||
)
|
||||
else:
|
||||
del _failed_attempts[client_ip]
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
username_ok = secrets.compare_digest(
|
||||
@@ -53,20 +49,12 @@ def admin_guard(
|
||||
)
|
||||
|
||||
if not (username_ok and password_ok):
|
||||
# 记录失败
|
||||
if client_ip in _failed_attempts:
|
||||
attempts, first_fail = _failed_attempts[client_ip]
|
||||
_failed_attempts[client_ip] = (attempts + 1, first_fail)
|
||||
else:
|
||||
_failed_attempts[client_ip] = (1, time.time())
|
||||
|
||||
await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
|
||||
# 登录成功,清除失败记录
|
||||
if client_ip in _failed_attempts:
|
||||
del _failed_attempts[client_ip]
|
||||
|
||||
await clear_failed_attempts(lockout_key)
|
||||
return True
|
||||
|
||||
141
backend/app/core/rate_limiter.py
Normal file
141
backend/app/core/rate_limiter.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Redis-backed rate limiter with in-memory fallback.
|
||||
|
||||
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
|
||||
Falls back to an in-memory TTLCache when Redis is unavailable,
|
||||
preserving identical behavior for dev/test environments.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from cachetools import TTLCache
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.redis import get_redis
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ── In-memory fallback caches ──────────────────────────────────────────────
|
||||
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
|
||||
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
|
||||
|
||||
|
||||
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
|
||||
"""Check and increment a sliding-window rate counter.
|
||||
|
||||
Args:
|
||||
key: Unique identifier (e.g. ``"story:<user_id>"``).
|
||||
limit: Maximum requests allowed within the window.
|
||||
window_seconds: Window duration in seconds.
|
||||
|
||||
Raises:
|
||||
HTTPException: 429 when the limit is exceeded.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
# Fixed-window bucket: key + minute boundary
|
||||
bucket = int(time.time() // window_seconds)
|
||||
redis_key = f"ratelimit:{key}:{bucket}"
|
||||
|
||||
count = await redis.incr(redis_key)
|
||||
if count == 1:
|
||||
await redis.expire(redis_key, window_seconds)
|
||||
|
||||
if count > limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many requests, please slow down.",
|
||||
)
|
||||
return
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("rate_limit_redis_fallback", error=str(exc))
|
||||
|
||||
# ── Fallback: in-memory counter ────────────────────────────────────────
|
||||
count = _local_rate_cache.get(key, 0) + 1
|
||||
_local_rate_cache[key] = count
|
||||
if count > limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many requests, please slow down.",
|
||||
)
|
||||
|
||||
|
||||
async def record_failed_attempt(
|
||||
key: str,
|
||||
max_attempts: int,
|
||||
lockout_seconds: int,
|
||||
) -> bool:
|
||||
"""Record a failed login attempt and return whether the key is locked out.
|
||||
|
||||
Args:
|
||||
key: Unique identifier (e.g. ``"admin_login:<ip>"``).
|
||||
max_attempts: Number of failures before lockout.
|
||||
lockout_seconds: Duration of lockout in seconds.
|
||||
|
||||
Returns:
|
||||
``True`` if the key is now locked out, ``False`` otherwise.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
redis_key = f"lockout:{key}"
|
||||
|
||||
count = await redis.incr(redis_key)
|
||||
if count == 1:
|
||||
await redis.expire(redis_key, lockout_seconds)
|
||||
|
||||
return count >= max_attempts
|
||||
except Exception as exc:
|
||||
logger.warning("lockout_redis_fallback", error=str(exc))
|
||||
|
||||
# ── Fallback ───────────────────────────────────────────────────────────
|
||||
if key in _local_lockout_cache:
|
||||
attempts, first_fail = _local_lockout_cache[key]
|
||||
_local_lockout_cache[key] = (attempts + 1, first_fail)
|
||||
return (attempts + 1) >= max_attempts
|
||||
else:
|
||||
_local_lockout_cache[key] = (1, time.time())
|
||||
return 1 >= max_attempts
|
||||
|
||||
|
||||
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
|
||||
"""Check if a key is currently locked out.
|
||||
|
||||
Returns:
|
||||
Remaining lockout seconds (> 0 means locked), 0 means not locked.
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
redis_key = f"lockout:{key}"
|
||||
|
||||
count = await redis.get(redis_key)
|
||||
if count is not None and int(count) >= max_attempts:
|
||||
ttl = await redis.ttl(redis_key)
|
||||
return max(ttl, 0)
|
||||
return 0
|
||||
except Exception as exc:
|
||||
logger.warning("lockout_check_redis_fallback", error=str(exc))
|
||||
|
||||
# ── Fallback ───────────────────────────────────────────────────────────
|
||||
if key in _local_lockout_cache:
|
||||
attempts, first_fail = _local_lockout_cache[key]
|
||||
if attempts >= max_attempts:
|
||||
remaining = int(lockout_seconds - (time.time() - first_fail))
|
||||
if remaining > 0:
|
||||
return remaining
|
||||
else:
|
||||
del _local_lockout_cache[key]
|
||||
return 0
|
||||
|
||||
|
||||
async def clear_failed_attempts(key: str) -> None:
|
||||
"""Clear lockout state on successful login."""
|
||||
try:
|
||||
redis = await get_redis()
|
||||
await redis.delete(f"lockout:{key}")
|
||||
except Exception as exc:
|
||||
logger.warning("lockout_clear_redis_fallback", error=str(exc))
|
||||
|
||||
# Always clear local cache too
|
||||
_local_lockout_cache.pop(key, None)
|
||||
52
backend/docs/stories_split_analysis.md
Normal file
52
backend/docs/stories_split_analysis.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# `stories.py` 拆分分析 (Phase 2 准备)
|
||||
|
||||
## 当前职责
|
||||
|
||||
`app/api/stories.py` (591 行) 承担了以下职责:
|
||||
|
||||
| 职责 | 行数 | 描述 |
|
||||
|---|---|---|
|
||||
| Pydantic 模型 | ~50 行 | `GenerateRequest`, `StoryResponse`, `FullStoryResponse` 等 |
|
||||
| 验证逻辑 | ~40 行 | `_validate_profile_and_universe` |
|
||||
| 路由 + 业务 | ~300 行 | `generate_story`, `generate_story_full`, `generate_story_stream` |
|
||||
| 绘本逻辑 | ~170 行 | `generate_storybook_api` (含并行图片生成) |
|
||||
| 成就查询 | ~30 行 | `get_story_achievements` |
|
||||
|
||||
## 缺失端点
|
||||
|
||||
测试中引用但 **未实现** 的端点(这些应在拆分时一并补充):
|
||||
|
||||
- `GET /api/stories` — 故事列表 (分页)
|
||||
- `GET /api/stories/{id}` — 故事详情
|
||||
- `DELETE /api/stories/{id}` — 故事删除
|
||||
- `POST /api/image/generate/{id}` — 封面图片生成
|
||||
- `GET /api/audio/{id}` — 语音朗读
|
||||
|
||||
## 建议拆分结构
|
||||
|
||||
```
|
||||
app/
|
||||
├── schemas/
|
||||
│ └── story_schemas.py # [NEW] Pydantic 模型
|
||||
├── services/
|
||||
│ └── story_service.py # [NEW] 核心业务逻辑
|
||||
└── api/
|
||||
├── stories.py # [SLIM] 路由定义 + 依赖注入
|
||||
└── stories_storybook.py # [NEW] 绘本相关端点 (可选)
|
||||
```
|
||||
|
||||
### `story_schemas.py`
|
||||
- 迁移所有 Pydantic 模型
|
||||
- 包括 `GenerateRequest`, `StoryResponse`, `FullStoryResponse`, `StorybookRequest`, `StorybookResponse` 等
|
||||
|
||||
### `story_service.py`
|
||||
- `validate_profile_and_universe()` — 验证逻辑
|
||||
- `create_story()` — 故事入库
|
||||
- `generate_and_save_story()` — 生成 + 保存联合操作
|
||||
- `generate_storybook_with_images()` — 绘本并行图片生成
|
||||
- 补充: `list_stories()`, `get_story()`, `delete_story()`
|
||||
|
||||
### `stories.py` (瘦路由层)
|
||||
- 仅保留 `@router` 装饰器和依赖注入
|
||||
- 调用 service 层完成业务逻辑
|
||||
- 预计 150-200 行
|
||||
@@ -12,7 +12,6 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||
|
||||
from app.core.security import create_access_token
|
||||
from app.api.stories import _request_log
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Base, Story, User
|
||||
from app.main import app
|
||||
@@ -96,11 +95,18 @@ def auth_client(client: TestClient, auth_token: str) -> TestClient:
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_rate_limit_cache():
|
||||
"""确保每个测试用例的限流缓存互不影响。"""
|
||||
_request_log.clear()
|
||||
yield
|
||||
_request_log.clear()
|
||||
def bypass_rate_limit():
|
||||
"""默认绕过限流,让非限流测试正常运行。"""
|
||||
with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
|
||||
# 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
|
||||
redis_instance = AsyncMock()
|
||||
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
|
||||
redis_instance.expire.return_value = True
|
||||
redis_instance.get.return_value = None # 无锁定记录
|
||||
redis_instance.ttl.return_value = 0
|
||||
redis_instance.delete.return_value = 1
|
||||
mock_redis.return_value = redis_instance
|
||||
yield redis_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
"""故事 API 测试。"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
|
||||
|
||||
# ── 注意 ──────────────────────────────────────────────────────────────────────
|
||||
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
|
||||
# GET /api/stories (列表)
|
||||
# GET /api/stories/{id} (详情)
|
||||
# DELETE /api/stories/{id} (删除)
|
||||
# POST /api/image/generate/{id} (封面图片生成)
|
||||
# GET /api/audio/{id} (音频)
|
||||
# 实现后请取消 skip 标记。
|
||||
|
||||
|
||||
class TestStoryGenerate:
|
||||
@@ -15,7 +22,7 @@ class TestStoryGenerate:
|
||||
def test_generate_without_auth(self, client: TestClient):
|
||||
"""未登录时生成故事。"""
|
||||
response = client.post(
|
||||
"/api/generate",
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -23,7 +30,7 @@ class TestStoryGenerate:
|
||||
def test_generate_with_empty_data(self, auth_client: TestClient):
|
||||
"""空数据生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -31,7 +38,7 @@ class TestStoryGenerate:
|
||||
def test_generate_with_invalid_type(self, auth_client: TestClient):
|
||||
"""无效类型生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
"/api/stories/generate",
|
||||
json={"type": "invalid", "data": "test"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
@@ -39,7 +46,7 @@ class TestStoryGenerate:
|
||||
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
|
||||
"""成功生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -50,6 +57,7 @@ class TestStoryGenerate:
|
||||
assert data["mode"] == "generated"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
|
||||
class TestStoryList:
|
||||
"""故事列表测试。"""
|
||||
|
||||
@@ -86,6 +94,7 @@ class TestStoryList:
|
||||
assert len(data) == 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
|
||||
class TestStoryDetail:
|
||||
"""故事详情测试。"""
|
||||
|
||||
@@ -109,6 +118,7 @@ class TestStoryDetail:
|
||||
assert data["story_text"] == test_story.story_text
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
|
||||
class TestStoryDelete:
|
||||
"""故事删除测试。"""
|
||||
|
||||
@@ -135,26 +145,30 @@ class TestStoryDelete:
|
||||
class TestRateLimit:
|
||||
"""Rate limit 测试。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清理 rate limit 缓存。"""
|
||||
_request_log.clear()
|
||||
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
|
||||
"""正常请求不触发限流。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS - 1):
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
# bypass_rate_limit 默认 incr 返回 1,不触发限流
|
||||
for _ in range(3):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story):
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
|
||||
"""超限请求被阻止。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS):
|
||||
auth_client.get(f"/api/stories/{test_story.id}")
|
||||
# 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
|
||||
bypass_rate_limit.incr.return_value = 11
|
||||
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert "Too many requests" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
|
||||
class TestImageGenerate:
|
||||
"""封面图片生成测试。"""
|
||||
|
||||
@@ -169,6 +183,7 @@ class TestImageGenerate:
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
|
||||
class TestAudio:
|
||||
"""语音朗读测试。"""
|
||||
|
||||
@@ -190,12 +205,12 @@ class TestAudio:
|
||||
|
||||
|
||||
class TestGenerateFull:
|
||||
"""完整故事生成测试(/api/generate/full)。"""
|
||||
"""完整故事生成测试(/api/stories/generate/full)。"""
|
||||
|
||||
def test_generate_full_without_auth(self, client: TestClient):
|
||||
"""未登录时生成完整故事。"""
|
||||
response = client.post(
|
||||
"/api/generate/full",
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
@@ -203,7 +218,7 @@ class TestGenerateFull:
|
||||
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""成功生成完整故事(含图片)。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -221,7 +236,7 @@ class TestGenerateFull:
|
||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -233,7 +248,7 @@ class TestGenerateFull:
|
||||
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""带教育主题生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
"/api/stories/generate/full",
|
||||
json={
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
@@ -246,6 +261,7 @@ class TestGenerateFull:
|
||||
assert call_kwargs["education_theme"] == "勇气与友谊"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
|
||||
class TestImageGenerateSuccess:
|
||||
"""封面图片生成成功测试。"""
|
||||
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
version: '3.8'
|
||||
# docker-compose.yml
|
||||
# 开发环境配置 - 支持本地构建和快速迭代
|
||||
#
|
||||
# 使用方式:
|
||||
# docker compose up -d # 启动所有服务
|
||||
# docker compose up -d --build # 重新构建并启动
|
||||
# docker compose logs -f backend # 查看日志
|
||||
#
|
||||
# 生产部署请使用: docker-compose.prod.yml
|
||||
|
||||
services:
|
||||
# ==============================================
|
||||
# 前端服务 (C端用户 App)
|
||||
# ==============================================
|
||||
frontend:
|
||||
build:
|
||||
context: ./frontend
|
||||
dockerfile: Dockerfile
|
||||
build: ./frontend
|
||||
image: dreamweaver-frontend:dev
|
||||
container_name: dreamweaver_frontend
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52080:80" # User App UI
|
||||
- "52080:80"
|
||||
depends_on:
|
||||
- backend
|
||||
|
||||
@@ -19,13 +26,12 @@ services:
|
||||
# 管理后台前端 (Admin Console)
|
||||
# ==============================================
|
||||
frontend-admin:
|
||||
build:
|
||||
context: ./admin-frontend
|
||||
dockerfile: Dockerfile
|
||||
build: ./admin-frontend
|
||||
image: dreamweaver-admin-frontend:dev
|
||||
container_name: dreamweaver_frontend_admin
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52888:80" # Admin Console UI
|
||||
- "52888:80"
|
||||
depends_on:
|
||||
- backend-admin
|
||||
|
||||
@@ -33,19 +39,19 @@ services:
|
||||
# 后端服务 (FastAPI)
|
||||
# ==============================================
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
dockerfile: Dockerfile
|
||||
build: ./backend
|
||||
image: dreamweaver-backend:dev
|
||||
container_name: dreamweaver_backend
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52000:8000" # User App API
|
||||
- "52000:8000"
|
||||
env_file:
|
||||
- ./backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
volumes:
|
||||
- backend_static:/app/static
|
||||
depends_on:
|
||||
@@ -55,72 +61,73 @@ services:
|
||||
condition: service_started
|
||||
|
||||
# ==============================================
|
||||
# 管理后台后端 (Admin Backend)
|
||||
# 管理后台后端 (Admin Backend) - 复用 backend 镜像
|
||||
# ==============================================
|
||||
backend-admin:
|
||||
build:
|
||||
context: ./backend
|
||||
dockerfile: Dockerfile
|
||||
image: dreamweaver-backend:dev
|
||||
container_name: dreamweaver_backend_admin
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52800:8001" # Admin API
|
||||
- "52800:8001"
|
||||
command: ["uvicorn", "app.admin_main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
env_file:
|
||||
- ./backend/.env
|
||||
environment:
|
||||
# 复用相同的 DB/Redis 连接
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
volumes:
|
||||
- backend_static:/app/static
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_started
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_started
|
||||
|
||||
# ==============================================
|
||||
|
||||
# ==============================================
|
||||
# 工作节点 (Celery Worker)
|
||||
# ==============================================
|
||||
# ==============================================
|
||||
# 工作节点 (Celery Worker)
|
||||
# 工作节点 (Celery Worker) - 复用 backend 镜像
|
||||
# ==============================================
|
||||
worker:
|
||||
build:
|
||||
context: ./backend
|
||||
image: dreamweaver-backend:dev
|
||||
container_name: dreamweaver_worker
|
||||
command: celery -A app.core.celery_app worker --loglevel=info
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
env_file: ./backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
depends_on:
|
||||
- backend
|
||||
- redis
|
||||
backend:
|
||||
condition: service_started
|
||||
redis:
|
||||
condition: service_started
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
# ==============================================
|
||||
# 调度节点 (Celery Beat)
|
||||
# 调度节点 (Celery Beat) - 复用 backend 镜像
|
||||
# ==============================================
|
||||
celery-beat:
|
||||
build:
|
||||
context: ./backend
|
||||
image: dreamweaver-backend:dev
|
||||
container_name: dreamweaver_beat
|
||||
command: celery -A app.core.celery_app beat --loglevel=info
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
env_file: ./backend/.env
|
||||
environment:
|
||||
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
depends_on:
|
||||
- backend
|
||||
- redis
|
||||
backend:
|
||||
condition: service_started
|
||||
redis:
|
||||
condition: service_started
|
||||
|
||||
# ==============================================
|
||||
# 数据库 (PostgreSQL)
|
||||
@@ -128,13 +135,13 @@ services:
|
||||
db:
|
||||
image: postgres:15-alpine
|
||||
container_name: dreamweaver_db
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-dreamweaver}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dreamweaver_password}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-dreamweaver_db}
|
||||
ports:
|
||||
- "52432:5432" # DB Host Port
|
||||
- "52432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
@@ -149,24 +156,26 @@ services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: dreamweaver_redis
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52379:6379" # Redis Host Port
|
||||
- "52379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: redis-server --appendonly yes
|
||||
|
||||
# ==============================================
|
||||
# 数据库管理 (Adminer)
|
||||
# 数据库管理 (Adminer) - 仅开发环境
|
||||
# ==============================================
|
||||
adminer:
|
||||
image: adminer
|
||||
container_name: dreamweaver_adminer
|
||||
restart: always
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "52999:8080" # Adminer UI
|
||||
- "52999:8080"
|
||||
depends_on:
|
||||
- db
|
||||
profiles:
|
||||
- dev # 仅在 --profile dev 时启动
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
||||
Reference in New Issue
Block a user