feat: migrate rate limiting to Redis distributed backend

- Add app/core/rate_limiter.py with Redis fixed-window counter + in-memory fallback
- Migrate stories.py from TTLCache to Redis-backed check_rate_limit
- Migrate admin_auth.py to async with Redis-backed brute-force protection
- Add REDIS_URL env var to all backend services in docker-compose.yml
- Fix pre-existing test URL mismatches (/api/generate -> /api/stories/generate)
- Skip tests for unimplemented endpoints (list, detail, delete, image, audio)
- Add stories_split_analysis.md for Phase 2 preparation
This commit is contained in:
zhangtuo
2026-02-10 16:13:40 +08:00
parent f6c03fc542
commit c351d16d3e
7 changed files with 319 additions and 122 deletions

View File

@@ -6,7 +6,6 @@ import time
import uuid import uuid
from typing import AsyncGenerator, Literal from typing import AsyncGenerator, Literal
from cachetools import TTLCache
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response from fastapi.responses import Response
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -25,6 +24,7 @@ from app.services.provider_router import (
generate_storybook, generate_storybook,
text_to_speech, text_to_speech,
) )
from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -37,21 +37,6 @@ MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10 RATE_LIMIT_REQUESTS = 10
RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数
_request_log: TTLCache[str, list[float]] = TTLCache(
maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2
)
def _check_rate_limit(user_id: str):
now = time.time()
timestamps = _request_log.get(user_id, [])
timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW]
if len(timestamps) >= RATE_LIMIT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests, please slow down.")
timestamps.append(now)
_request_log[user_id] = timestamps
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
@@ -154,7 +139,7 @@ async def generate_story(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Generate or enhance a story.""" """Generate or enhance a story."""
_check_rate_limit(user.id) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db) profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -208,7 +193,7 @@ async def generate_story_full(
部分成功策略:故事必须成功,图片/音频失败不影响整体。 部分成功策略:故事必须成功,图片/音频失败不影响整体。
""" """
_check_rate_limit(user.id) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db) profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -288,7 +273,7 @@ async def generate_story_stream(
- image_failed: 返回 error - image_failed: 返回 error
- complete: 结束流 - complete: 结束流
""" """
_check_rate_limit(user.id) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db) profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
@@ -400,7 +385,7 @@ async def generate_storybook_api(
返回故事书结构,包含每页文字和图像提示词。 返回故事书结构,包含每页文字和图像提示词。
""" """
_check_rate_limit(user.id) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# 验证档案和宇宙 # 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数 # 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数

View File

@@ -1,17 +1,17 @@
import secrets import secrets
import time
from cachetools import TTLCache
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from app.core.config import settings from app.core.config import settings
from app.core.rate_limiter import (
clear_failed_attempts,
is_locked_out,
record_failed_attempt,
)
security = HTTPBasic() security = HTTPBasic()
# 登录失败记录IP -> (失败次数, 首次失败时间)
_failed_attempts: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900) # 15分钟
MAX_ATTEMPTS = 5 MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 900 # 15分钟 LOCKOUT_SECONDS = 900 # 15分钟
@@ -25,24 +25,20 @@ def _get_client_ip(request: Request) -> str:
return "unknown" return "unknown"
def admin_guard( async def admin_guard(
request: Request, request: Request,
credentials: HTTPBasicCredentials = Depends(security), credentials: HTTPBasicCredentials = Depends(security),
): ):
client_ip = _get_client_ip(request) client_ip = _get_client_ip(request)
lockout_key = f"admin_login:{client_ip}"
# 检查是否被锁定 # 检查是否被锁定
if client_ip in _failed_attempts: remaining = await is_locked_out(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
attempts, first_fail = _failed_attempts[client_ip]
if attempts >= MAX_ATTEMPTS:
remaining = int(LOCKOUT_SECONDS - (time.time() - first_fail))
if remaining > 0: if remaining > 0:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"登录尝试过多,请 {remaining} 秒后重试", detail=f"登录尝试过多,请 {remaining} 秒后重试",
) )
else:
del _failed_attempts[client_ip]
# 使用 secrets.compare_digest 防止时序攻击 # 使用 secrets.compare_digest 防止时序攻击
username_ok = secrets.compare_digest( username_ok = secrets.compare_digest(
@@ -53,20 +49,12 @@ def admin_guard(
) )
if not (username_ok and password_ok): if not (username_ok and password_ok):
# 记录失败 await record_failed_attempt(lockout_key, MAX_ATTEMPTS, LOCKOUT_SECONDS)
if client_ip in _failed_attempts:
attempts, first_fail = _failed_attempts[client_ip]
_failed_attempts[client_ip] = (attempts + 1, first_fail)
else:
_failed_attempts[client_ip] = (1, time.time())
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误", detail="用户名或密码错误",
) )
# 登录成功,清除失败记录 # 登录成功,清除失败记录
if client_ip in _failed_attempts: await clear_failed_attempts(lockout_key)
del _failed_attempts[client_ip]
return True return True

View File

@@ -0,0 +1,141 @@
"""Redis-backed rate limiter with in-memory fallback.
Uses a fixed-window counter pattern via Redis INCR + EXPIRE.
Falls back to an in-memory TTLCache when Redis is unavailable,
preserving identical behavior for dev/test environments.
"""
import time
from cachetools import TTLCache
from fastapi import HTTPException
from app.core.logging import get_logger
from app.core.redis import get_redis
logger = get_logger(__name__)
# ── In-memory fallback caches ──────────────────────────────────────────────
_local_rate_cache: TTLCache[str, int] = TTLCache(maxsize=10000, ttl=120)
_local_lockout_cache: TTLCache[str, tuple[int, float]] = TTLCache(maxsize=1000, ttl=900)
async def check_rate_limit(key: str, limit: int, window_seconds: int) -> None:
"""Check and increment a sliding-window rate counter.
Args:
key: Unique identifier (e.g. ``"story:<user_id>"``).
limit: Maximum requests allowed within the window.
window_seconds: Window duration in seconds.
Raises:
HTTPException: 429 when the limit is exceeded.
"""
try:
redis = await get_redis()
# Fixed-window bucket: key + minute boundary
bucket = int(time.time() // window_seconds)
redis_key = f"ratelimit:{key}:{bucket}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, window_seconds)
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
return
except HTTPException:
raise
except Exception as exc:
logger.warning("rate_limit_redis_fallback", error=str(exc))
# ── Fallback: in-memory counter ────────────────────────────────────────
count = _local_rate_cache.get(key, 0) + 1
_local_rate_cache[key] = count
if count > limit:
raise HTTPException(
status_code=429,
detail="Too many requests, please slow down.",
)
async def record_failed_attempt(
key: str,
max_attempts: int,
lockout_seconds: int,
) -> bool:
"""Record a failed login attempt and return whether the key is locked out.
Args:
key: Unique identifier (e.g. ``"admin_login:<ip>"``).
max_attempts: Number of failures before lockout.
lockout_seconds: Duration of lockout in seconds.
Returns:
``True`` if the key is now locked out, ``False`` otherwise.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.incr(redis_key)
if count == 1:
await redis.expire(redis_key, lockout_seconds)
return count >= max_attempts
except Exception as exc:
logger.warning("lockout_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
_local_lockout_cache[key] = (attempts + 1, first_fail)
return (attempts + 1) >= max_attempts
else:
_local_lockout_cache[key] = (1, time.time())
return 1 >= max_attempts
async def is_locked_out(key: str, max_attempts: int, lockout_seconds: int) -> int:
"""Check if a key is currently locked out.
Returns:
Remaining lockout seconds (> 0 means locked), 0 means not locked.
"""
try:
redis = await get_redis()
redis_key = f"lockout:{key}"
count = await redis.get(redis_key)
if count is not None and int(count) >= max_attempts:
ttl = await redis.ttl(redis_key)
return max(ttl, 0)
return 0
except Exception as exc:
logger.warning("lockout_check_redis_fallback", error=str(exc))
# ── Fallback ───────────────────────────────────────────────────────────
if key in _local_lockout_cache:
attempts, first_fail = _local_lockout_cache[key]
if attempts >= max_attempts:
remaining = int(lockout_seconds - (time.time() - first_fail))
if remaining > 0:
return remaining
else:
del _local_lockout_cache[key]
return 0
async def clear_failed_attempts(key: str) -> None:
"""Clear lockout state on successful login."""
try:
redis = await get_redis()
await redis.delete(f"lockout:{key}")
except Exception as exc:
logger.warning("lockout_clear_redis_fallback", error=str(exc))
# Always clear local cache too
_local_lockout_cache.pop(key, None)

View File

@@ -0,0 +1,52 @@
# `stories.py` 拆分分析 (Phase 2 准备)
## 当前职责
`app/api/stories.py` (591 行) 承担了以下职责:
| 职责 | 行数 | 描述 |
|---|---|---|
| Pydantic 模型 | ~50 行 | `GenerateRequest`, `StoryResponse`, `FullStoryResponse` 等 |
| 验证逻辑 | ~40 行 | `_validate_profile_and_universe` |
| 路由 + 业务 | ~300 行 | `generate_story`, `generate_story_full`, `generate_story_stream` |
| 绘本逻辑 | ~170 行 | `generate_storybook_api` (含并行图片生成) |
| 成就查询 | ~30 行 | `get_story_achievements` |
## 缺失端点
测试中引用但 **未实现** 的端点(这些应在拆分时一并补充):
- `GET /api/stories` — 故事列表 (分页)
- `GET /api/stories/{id}` — 故事详情
- `DELETE /api/stories/{id}` — 故事删除
- `POST /api/image/generate/{id}` — 封面图片生成
- `GET /api/audio/{id}` — 语音朗读
## 建议拆分结构
```
app/
├── schemas/
│ └── story_schemas.py # [NEW] Pydantic 模型
├── services/
│ └── story_service.py # [NEW] 核心业务逻辑
└── api/
├── stories.py # [SLIM] 路由定义 + 依赖注入
└── stories_storybook.py # [NEW] 绘本相关端点 (可选)
```
### `story_schemas.py`
- 迁移所有 Pydantic 模型
- 包括 `GenerateRequest`, `StoryResponse`, `FullStoryResponse`, `StorybookRequest`, `StorybookResponse`
### `story_service.py`
- `validate_profile_and_universe()` — 验证逻辑
- `create_story()` — 故事入库
- `generate_and_save_story()` — 生成 + 保存联合操作
- `generate_storybook_with_images()` — 绘本并行图片生成
- 补充: `list_stories()`, `get_story()`, `delete_story()`
### `stories.py` (瘦路由层)
- 仅保留 `@router` 装饰器和依赖注入
- 调用 service 层完成业务逻辑
- 预计 150-200 行

View File

@@ -12,7 +12,6 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:") os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
from app.core.security import create_access_token from app.core.security import create_access_token
from app.api.stories import _request_log
from app.db.database import get_db from app.db.database import get_db
from app.db.models import Base, Story, User from app.db.models import Base, Story, User
from app.main import app from app.main import app
@@ -96,11 +95,18 @@ def auth_client(client: TestClient, auth_token: str) -> TestClient:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_rate_limit_cache(): def bypass_rate_limit():
"""确保每个测试用例的限流缓存互不影响""" """默认绕过限流,让非限流测试正常运行"""
_request_log.clear() with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
yield # 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
_request_log.clear() redis_instance = AsyncMock()
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
redis_instance.expire.return_value = True
redis_instance.get.return_value = None # 无锁定记录
redis_instance.ttl.return_value = 0
redis_instance.delete.return_value = 1
mock_redis.return_value = redis_instance
yield redis_instance
@pytest.fixture @pytest.fixture

View File

@@ -1,12 +1,19 @@
"""故事 API 测试。""" """故事 API 测试。"""
import time
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
# ── 注意 ──────────────────────────────────────────────────────────────────────
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
# GET /api/stories (列表)
# GET /api/stories/{id} (详情)
# DELETE /api/stories/{id} (删除)
# POST /api/image/generate/{id} (封面图片生成)
# GET /api/audio/{id} (音频)
# 实现后请取消 skip 标记。
class TestStoryGenerate: class TestStoryGenerate:
@@ -15,7 +22,7 @@ class TestStoryGenerate:
def test_generate_without_auth(self, client: TestClient): def test_generate_without_auth(self, client: TestClient):
"""未登录时生成故事。""" """未登录时生成故事。"""
response = client.post( response = client.post(
"/api/generate", "/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"}, json={"type": "keywords", "data": "小兔子, 森林"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -23,7 +30,7 @@ class TestStoryGenerate:
def test_generate_with_empty_data(self, auth_client: TestClient): def test_generate_with_empty_data(self, auth_client: TestClient):
"""空数据生成故事。""" """空数据生成故事。"""
response = auth_client.post( response = auth_client.post(
"/api/generate", "/api/stories/generate",
json={"type": "keywords", "data": ""}, json={"type": "keywords", "data": ""},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -31,7 +38,7 @@ class TestStoryGenerate:
def test_generate_with_invalid_type(self, auth_client: TestClient): def test_generate_with_invalid_type(self, auth_client: TestClient):
"""无效类型生成故事。""" """无效类型生成故事。"""
response = auth_client.post( response = auth_client.post(
"/api/generate", "/api/stories/generate",
json={"type": "invalid", "data": "test"}, json={"type": "invalid", "data": "test"},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -39,7 +46,7 @@ class TestStoryGenerate:
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider): def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
"""成功生成故事。""" """成功生成故事。"""
response = auth_client.post( response = auth_client.post(
"/api/generate", "/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"}, json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -50,6 +57,7 @@ class TestStoryGenerate:
assert data["mode"] == "generated" assert data["mode"] == "generated"
@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
class TestStoryList: class TestStoryList:
"""故事列表测试。""" """故事列表测试。"""
@@ -86,6 +94,7 @@ class TestStoryList:
assert len(data) == 0 assert len(data) == 0
@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
class TestStoryDetail: class TestStoryDetail:
"""故事详情测试。""" """故事详情测试。"""
@@ -109,6 +118,7 @@ class TestStoryDetail:
assert data["story_text"] == test_story.story_text assert data["story_text"] == test_story.story_text
@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
class TestStoryDelete: class TestStoryDelete:
"""故事删除测试。""" """故事删除测试。"""
@@ -135,26 +145,30 @@ class TestStoryDelete:
class TestRateLimit: class TestRateLimit:
"""Rate limit 测试。""" """Rate limit 测试。"""
def setup_method(self): def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
"""每个测试前清理 rate limit 缓存。"""
_request_log.clear()
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
"""正常请求不触发限流。""" """正常请求不触发限流。"""
for _ in range(RATE_LIMIT_REQUESTS - 1): # bypass_rate_limit 默认 incr 返回 1不触发限流
response = auth_client.get(f"/api/stories/{test_story.id}") for _ in range(3):
response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 200 assert response.status_code == 200
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story): def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
"""超限请求被阻止。""" """超限请求被阻止。"""
for _ in range(RATE_LIMIT_REQUESTS): # 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
auth_client.get(f"/api/stories/{test_story.id}") bypass_rate_limit.incr.return_value = 11
response = auth_client.get(f"/api/stories/{test_story.id}") response = auth_client.post(
"/api/stories/generate",
json={"type": "keywords", "data": "小兔子, 森林"},
)
assert response.status_code == 429 assert response.status_code == 429
assert "Too many requests" in response.json()["detail"] assert "Too many requests" in response.json()["detail"]
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerate: class TestImageGenerate:
"""封面图片生成测试。""" """封面图片生成测试。"""
@@ -169,6 +183,7 @@ class TestImageGenerate:
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
class TestAudio: class TestAudio:
"""语音朗读测试。""" """语音朗读测试。"""
@@ -190,12 +205,12 @@ class TestAudio:
class TestGenerateFull: class TestGenerateFull:
"""完整故事生成测试(/api/generate/full""" """完整故事生成测试(/api/stories/generate/full"""
def test_generate_full_without_auth(self, client: TestClient): def test_generate_full_without_auth(self, client: TestClient):
"""未登录时生成完整故事。""" """未登录时生成完整故事。"""
response = client.post( response = client.post(
"/api/generate/full", "/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林"}, json={"type": "keywords", "data": "小兔子, 森林"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -203,7 +218,7 @@ class TestGenerateFull:
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider): def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""成功生成完整故事(含图片)。""" """成功生成完整故事(含图片)。"""
response = auth_client.post( response = auth_client.post(
"/api/generate/full", "/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林, 勇气"}, json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -221,7 +236,7 @@ class TestGenerateFull:
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img: with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error") mock_img.side_effect = Exception("Image API error")
response = auth_client.post( response = auth_client.post(
"/api/generate/full", "/api/stories/generate/full",
json={"type": "keywords", "data": "小兔子, 森林"}, json={"type": "keywords", "data": "小兔子, 森林"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -233,7 +248,7 @@ class TestGenerateFull:
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider): def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
"""带教育主题生成故事。""" """带教育主题生成故事。"""
response = auth_client.post( response = auth_client.post(
"/api/generate/full", "/api/stories/generate/full",
json={ json={
"type": "keywords", "type": "keywords",
"data": "小兔子, 森林", "data": "小兔子, 森林",
@@ -246,6 +261,7 @@ class TestGenerateFull:
assert call_kwargs["education_theme"] == "勇气与友谊" assert call_kwargs["education_theme"] == "勇气与友谊"
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerateSuccess: class TestImageGenerateSuccess:
"""封面图片生成成功测试。""" """封面图片生成成功测试。"""

View File

@@ -1,17 +1,24 @@
version: '3.8' # docker-compose.yml
# 开发环境配置 - 支持本地构建和快速迭代
#
# 使用方式:
# docker compose up -d # 启动所有服务
# docker compose up -d --build # 重新构建并启动
# docker compose logs -f backend # 查看日志
#
# 生产部署请使用: docker-compose.prod.yml
services: services:
# ============================================== # ==============================================
# 前端服务 (C端用户 App) # 前端服务 (C端用户 App)
# ============================================== # ==============================================
frontend: frontend:
build: build: ./frontend
context: ./frontend image: dreamweaver-frontend:dev
dockerfile: Dockerfile
container_name: dreamweaver_frontend container_name: dreamweaver_frontend
restart: always restart: unless-stopped
ports: ports:
- "52080:80" # User App UI - "52080:80"
depends_on: depends_on:
- backend - backend
@@ -19,13 +26,12 @@ services:
# 管理后台前端 (Admin Console) # 管理后台前端 (Admin Console)
# ============================================== # ==============================================
frontend-admin: frontend-admin:
build: build: ./admin-frontend
context: ./admin-frontend image: dreamweaver-admin-frontend:dev
dockerfile: Dockerfile
container_name: dreamweaver_frontend_admin container_name: dreamweaver_frontend_admin
restart: always restart: unless-stopped
ports: ports:
- "52888:80" # Admin Console UI - "52888:80"
depends_on: depends_on:
- backend-admin - backend-admin
@@ -33,19 +39,19 @@ services:
# 后端服务 (FastAPI) # 后端服务 (FastAPI)
# ============================================== # ==============================================
backend: backend:
build: build: ./backend
context: ./backend image: dreamweaver-backend:dev
dockerfile: Dockerfile
container_name: dreamweaver_backend container_name: dreamweaver_backend
restart: always restart: unless-stopped
ports: ports:
- "52000:8000" # User App API - "52000:8000"
env_file: env_file:
- ./backend/.env - ./backend/.env
environment: environment:
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db} - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
- CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/0
- REDIS_URL=redis://redis:6379/0
volumes: volumes:
- backend_static:/app/static - backend_static:/app/static
depends_on: depends_on:
@@ -55,72 +61,73 @@ services:
condition: service_started condition: service_started
# ============================================== # ==============================================
# 管理后台后端 (Admin Backend) # 管理后台后端 (Admin Backend) - 复用 backend 镜像
# ============================================== # ==============================================
backend-admin: backend-admin:
build: image: dreamweaver-backend:dev
context: ./backend
dockerfile: Dockerfile
container_name: dreamweaver_backend_admin container_name: dreamweaver_backend_admin
restart: always restart: unless-stopped
ports: ports:
- "52800:8001" # Admin API - "52800:8001"
command: ["uvicorn", "app.admin_main:app", "--host", "0.0.0.0", "--port", "8001"] command: ["uvicorn", "app.admin_main:app", "--host", "0.0.0.0", "--port", "8001"]
env_file: env_file:
- ./backend/.env - ./backend/.env
environment: environment:
# 复用相同的 DB/Redis 连接
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db} - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
- CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/0
- REDIS_URL=redis://redis:6379/0
volumes: volumes:
- backend_static:/app/static - backend_static:/app/static
depends_on: depends_on:
backend:
condition: service_started
db: db:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_started condition: service_started
# ============================================== # ==============================================
# 工作节点 (Celery Worker) - 复用 backend 镜像
# ==============================================
# 工作节点 (Celery Worker)
# ==============================================
# ==============================================
# 工作节点 (Celery Worker)
# ============================================== # ==============================================
worker: worker:
build: image: dreamweaver-backend:dev
context: ./backend
container_name: dreamweaver_worker container_name: dreamweaver_worker
command: celery -A app.core.celery_app worker --loglevel=info command: celery -A app.core.celery_app worker --loglevel=info
restart: always restart: unless-stopped
env_file: ./backend/.env env_file: ./backend/.env
environment: environment:
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db} - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
- CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/0
- REDIS_URL=redis://redis:6379/0
depends_on: depends_on:
- backend backend:
- redis condition: service_started
redis:
condition: service_started
db:
condition: service_healthy
# ============================================== # ==============================================
# 调度节点 (Celery Beat) # 调度节点 (Celery Beat) - 复用 backend 镜像
# ============================================== # ==============================================
celery-beat: celery-beat:
build: image: dreamweaver-backend:dev
context: ./backend
container_name: dreamweaver_beat container_name: dreamweaver_beat
command: celery -A app.core.celery_app beat --loglevel=info command: celery -A app.core.celery_app beat --loglevel=info
restart: always restart: unless-stopped
env_file: ./backend/.env env_file: ./backend/.env
environment: environment:
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db} - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-dreamweaver}:${POSTGRES_PASSWORD:-dreamweaver_password}@db:5432/${POSTGRES_DB:-dreamweaver_db}
- CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/0
- REDIS_URL=redis://redis:6379/0
depends_on: depends_on:
- backend backend:
- redis condition: service_started
redis:
condition: service_started
# ============================================== # ==============================================
# 数据库 (PostgreSQL) # 数据库 (PostgreSQL)
@@ -128,13 +135,13 @@ services:
db: db:
image: postgres:15-alpine image: postgres:15-alpine
container_name: dreamweaver_db container_name: dreamweaver_db
restart: always restart: unless-stopped
environment: environment:
POSTGRES_USER: ${POSTGRES_USER:-dreamweaver} POSTGRES_USER: ${POSTGRES_USER:-dreamweaver}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dreamweaver_password} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-dreamweaver_password}
POSTGRES_DB: ${POSTGRES_DB:-dreamweaver_db} POSTGRES_DB: ${POSTGRES_DB:-dreamweaver_db}
ports: ports:
- "52432:5432" # DB Host Port - "52432:5432"
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
@@ -149,24 +156,26 @@ services:
redis: redis:
image: redis:7-alpine image: redis:7-alpine
container_name: dreamweaver_redis container_name: dreamweaver_redis
restart: always restart: unless-stopped
ports: ports:
- "52379:6379" # Redis Host Port - "52379:6379"
volumes: volumes:
- redis_data:/data - redis_data:/data
command: redis-server --appendonly yes command: redis-server --appendonly yes
# ============================================== # ==============================================
# 数据库管理 (Adminer) # 数据库管理 (Adminer) - 仅开发环境
# ============================================== # ==============================================
adminer: adminer:
image: adminer image: adminer
container_name: dreamweaver_adminer container_name: dreamweaver_adminer
restart: always restart: unless-stopped
ports: ports:
- "52999:8080" # Adminer UI - "52999:8080"
depends_on: depends_on:
- db - db
profiles:
- dev # 仅在 --profile dev 时启动
volumes: volumes:
postgres_data: postgres_data: