Refactor Phase 2: Split stories.py into Schema/Service/Controller, add missing endpoints, fix async bug

This commit is contained in:
zhangtuo
2026-02-10 17:14:54 +08:00
parent c351d16d3e
commit 9cdff18336
13 changed files with 881 additions and 612 deletions

View File

@@ -2,26 +2,32 @@ FROM python:3.11-slim
WORKDIR /app WORKDIR /app
# 安装系统依赖 (如果需要) # 设置环境变量
# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/* ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
# 复制项目文件 # 安装系统工具 (curl用于可能的健康检查)
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
# 1. 缓存层:仅复制依赖定义并安装
# 创建伪造的 app 目录以满足 pip install . 的要求
COPY pyproject.toml . COPY pyproject.toml .
# 复制源码 RUN mkdir app && touch app/__init__.py
RUN pip install --no-cache-dir .
# 2. 源码层:复制真实代码
COPY app ./app COPY app ./app
COPY alembic ./alembic COPY alembic ./alembic
COPY alembic.ini . COPY alembic.ini .
# 安装依赖 # 再次安装本身(不带依赖),确保源码更新被标记为已安装
# 使用 pip 安装当前目录 (.),会自动解析 pyproject.toml RUN pip install --no-cache-dir --no-deps .
RUN pip install --no-cache-dir .
# 创建静态文件目录 (用于存放生成的图片) # 创建静态文件目录
RUN mkdir -p static/images RUN mkdir -p static/images
# 暴露端口 # 暴露端口
EXPOSE 8000 EXPOSE 8000
# 启动命令 # 默认启动命令(可被 docker-compose 覆盖)
# 生产环境建议使用 gunicorn 或 uvicorn --workers
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -2,136 +2,41 @@
import asyncio import asyncio
import json import json
import time
import uuid import uuid
from typing import AsyncGenerator, Literal from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user from app.core.deps import require_user
from app.core.logging import get_logger from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.core.rate_limiter import check_rate_limit from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
GenerateRequest,
StoryResponse,
FullStoryResponse,
StorybookRequest,
StorybookResponse,
StoryListItem,
AchievementItem,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
)
logger = get_logger(__name__) logger = get_logger(__name__)
router = APIRouter() router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10 RATE_LIMIT_REQUESTS = 10
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse) @router.post("/stories/generate", response_model=StoryResponse)
async def generate_story( async def generate_story(
request: GenerateRequest, request: GenerateRequest,
@@ -140,47 +45,7 @@ async def generate_story(
): ):
"""Generate or enhance a story.""" """Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db) return await story_service.generate_and_save_story(request, user.id, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/full", response_model=FullStoryResponse) @router.post("/stories/generate/full", response_model=FullStoryResponse)
@@ -189,71 +54,9 @@ async def generate_story_full(
user: User = Depends(require_user), user: User = Depends(require_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""生成完整故事(故事 + 并行生成图片和音频)。 """Generate complete story (story + parallel image/audio generation)."""
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db) return await story_service.generate_full_story_service(request, user.id, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# Step 1: 故事生成(必须成功)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# Step 2: 生成封面图片(音频按需生成,避免浪费)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
# 这样避免生成后丢弃造成的成本浪费
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False, # 音频需要用户主动请求
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
@router.post("/stories/generate/stream") @router.post("/stories/generate/stream")
@@ -263,53 +66,39 @@ async def generate_story_stream(
user: User = Depends(require_user), user: User = Depends(require_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""流式生成故事SSE """流式生成故事SSE"""
事件流程:
- started: 返回 story_id
- story_ready: 返回 title, content
- story_failed: 返回 error
- image_ready: 返回 image_url
- image_failed: 返回 error
- complete: 结束流
"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
# Validation
profile_id, universe_id = await story_service.validate_profile_and_universe(
request.child_profile_id, request.universe_id, user.id, db
)
# Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]: async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4()) story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})} yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: 生成故事 # Step 1: Generate Content
try: try:
result = await generate_story_content( result = await generate_story_content(
input_type=request.type, input_type=request.type,
data=request.data, data=request.data,
education_theme=request.education_theme, education_theme=request.education_theme,
memory_context=memory_context, memory_context=memory_context,
db=db,
) )
except Exception as e: except Exception as e:
logger.error("sse_story_generation_failed", error=str(e)) logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})} yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return return
# 保存故事 # Save Story
story = Story( story = await story_service.create_story_from_result(
user_id=user.id, result, user.id, profile_id, universe_id, db
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
) )
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield { yield {
"event": "story_ready", "event": "story_ready",
@@ -324,10 +113,11 @@ async def generate_story_stream(
}), }),
} }
# Step 2: 并行生成图片(音频按需) # Step 2: Generate Image
if story.cover_prompt: if story.cover_prompt:
try: try:
image_url = await generate_image(story.cover_prompt) # Direct call to provider router's generate_image, sharing db session
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url story.image_url = image_url
await db.commit() await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})} yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
@@ -340,217 +130,71 @@ async def generate_story_stream(
return EventSourceResponse(event_generator()) return EventSourceResponse(event_generator())
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse) @router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api( async def generate_storybook_api(
request: StorybookRequest, request: StorybookRequest,
user: User = Depends(require_user), user: User = Depends(require_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""生成分页故事书并保存。 """Generate storybook."""
返回故事书结构,包含每页文字和图像提示词。
"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_storybook_service(request, user.id, db)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# ==============================================================================
# 核心升级: 并行全量生成 (Parallel Full Rendering)
# ==============================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
# 1. 准备所有生图任务 (封面 + 所有内页)
tasks = []
# 封面任务
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# 内页任务
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# 2. 并发执行所有任务
# 使用 return_exceptions=True 防止单张失败影响整体
results = await asyncio.gather(*tasks, return_exceptions=True)
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# ==============================================================================
# 构建并保存 Story 对象
# 将 pages 对象转换为字典列表以存入 JSON 字段
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data, # 存入 JSON 字段
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 构建响应 (使用更新后的 pages_data)
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
class AchievementItem(BaseModel): # ==================== Missing Endpoints (Issue #5) ====================
type: str
description: str @router.get("/stories", response_model=list[StoryListItem])
obtained_at: str | None = None async def list_stories(
limit: int = 20,
offset: int = 0,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List stories."""
return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryResponse)
async def get_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story detail."""
return await story_service.get_story_detail(story_id, user.id, db)
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete story."""
await story_service.delete_story(story_id, user.id, db)
return {"message": "Deleted"}
@router.post("/image/generate/{story_id}")
async def generate_story_image(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate cover image for story."""
url = await story_service.generate_story_cover(story_id, user.id, db)
return {"image_url": url}
@router.get("/audio/{story_id}")
async def get_story_audio(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story audio (MP3)."""
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem]) @router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
@@ -559,32 +203,5 @@ async def get_story_achievements(
user: User = Depends(require_user), user: User = Depends(require_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Get achievements unlocked by a specific story.""" """Get story achievements."""
# 使用 joinedload 避免 N+1 查询 return await story_service.get_story_achievements(story_id, user.id, db)
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

View File

@@ -53,6 +53,9 @@ class Settings(BaseSettings):
celery_broker_url: str = Field("redis://localhost:6379/0") celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0") celery_result_backend: str = Field("redis://localhost:6379/0")
# Generic Redis
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
# Admin console # Admin console
enable_admin_console: bool = False enable_admin_console: bool = False
admin_username: str = "admin" admin_username: str = "admin"

25
backend/app/core/redis.py Normal file
View File

@@ -0,0 +1,25 @@
"""Redis client module."""
from typing import AsyncGenerator
from redis.asyncio import Redis, from_url
from app.core.config import settings
_redis_pool: Redis | None = None
async def get_redis() -> Redis:
"""Get global Redis client instance."""
global _redis_pool
if _redis_pool is None:
_redis_pool = from_url(settings.redis_url, encoding="utf-8", decode_responses=True)
return _redis_pool
async def close_redis():
"""Close Redis connection."""
global _redis_pool
if _redis_pool:
await _redis_pool.close()
_redis_pool = None

View File

@@ -0,0 +1 @@
"""故事相关 Schema 模块。"""

View File

@@ -0,0 +1,106 @@
"""故事相关 Pydantic 模型。"""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, Field
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
# ==================== 故事模型 ====================
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: datetime
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
# ==================== 绘本模型 ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
# ==================== 成就模型 ====================
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None

View File

@@ -1,31 +1,109 @@
"""In-memory cache for providers loaded from DB.""" """Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict from collections import defaultdict
from typing import Literal from typing import Literal
from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"] ProviderType = Literal["text", "image", "tts", "storybook"]
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""
id: str
name: str
type: str
adapter: str
model: str | None = None
api_base: str | None = None
api_key: str | None = None
timeout_ms: int = 60000
max_retries: int = 1
weight: int = 1
priority: int = 0
enabled: bool = True
config_json: dict | None = None
config_ref: str | None = None
async def reload_providers(db: AsyncSession): # Local memory fallback (L1 cache)
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712 _local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
providers = result.scalars().all() CACHE_KEY = "dreamweaver:providers:config"
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
for p in providers:
grouped[p.type].append(p)
# sort by priority desc, then weight desc
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
_cache.clear()
_cache.update(grouped)
return _cache
def get_providers(provider_type: ProviderType) -> list[Provider]: async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
return _cache.get(provider_type, []) """Reload providers from DB and update Redis cache."""
try:
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
providers = result.scalars().all()
# Convert to Pydantic models
cached_list = []
for p in providers:
cached_list.append(CachedProvider(
id=p.id,
name=p.name,
type=p.type,
adapter=p.adapter,
model=p.model,
api_base=p.api_base,
api_key=p.api_key,
timeout_ms=p.timeout_ms,
max_retries=p.max_retries,
weight=p.weight,
priority=p.priority,
enabled=p.enabled,
config_json=p.config_json,
config_ref=p.config_ref
))
# Group by type
grouped: dict[str, list[CachedProvider]] = defaultdict(list)
for cp in cached_list:
grouped[cp.type].append(cp)
# Sort
for k in grouped:
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
# Update Redis
redis = await get_redis()
# Serialize entire dict structure
# Pydantic -> dict -> json
json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
await redis.set(CACHE_KEY, json.dumps(json_data))
# Update local cache
_local_cache.clear()
_local_cache.update(grouped)
return grouped
except Exception as e:
logger.error("failed_to_reload_providers", error=str(e))
raise
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
"""Get providers from Redis (preferred) or local fallback."""
try:
redis = await get_redis()
data = await redis.get(CACHE_KEY)
if data:
raw_dict = json.loads(data)
if provider_type in raw_dict:
return [CachedProvider(**item) for item in raw_dict[provider_type]]
return []
except Exception as e:
logger.warning("redis_cache_read_failed", error=str(e))
# Fallback to local memory
return _local_cache.get(provider_type, [])

View File

@@ -177,7 +177,7 @@ def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
) )
def _get_providers_with_config( async def _get_providers_with_config(
provider_type: ProviderType, provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]: ) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。 """获取供应商列表及其配置。
@@ -185,7 +185,7 @@ def _get_providers_with_config(
Returns: Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序 [(adapter_name, config, provider_or_none), ...] 按优先级排序
""" """
db_providers = get_providers(provider_type) db_providers = await get_providers(provider_type)
if db_providers: if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers] return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
@@ -265,7 +265,7 @@ async def _route_with_failover(
user_id: 用户 ID可选用于成本追踪和预算检查 user_id: 用户 ID可选用于成本追踪和预算检查
**kwargs: 传递给适配器的参数 **kwargs: 传递给适配器的参数
""" """
providers = _get_providers_with_config(provider_type) providers = await _get_providers_with_config(provider_type)
if not providers: if not providers:
raise ValueError(f"No {provider_type} providers configured.") raise ValueError(f"No {provider_type} providers configured.")

View File

@@ -0,0 +1,434 @@
"""Story business logic service."""
import asyncio
import json
import uuid
from typing import Literal
from fastapi import HTTPException
from sqlalchemy import select, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
GenerateRequest,
StorybookRequest,
FullStoryResponse,
StorybookResponse,
StorybookPageResponse,
AchievementItem,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
generate_storybook,
)
from app.tasks.achievements import extract_story_achievements
logger = get_logger(__name__)
async def validate_profile_and_universe(
profile_id: str | None,
universe_id: str | None,
user_id: str,
db: AsyncSession,
) -> tuple[str | None, str | None]:
"""Validate child profile and universe ownership/relationship."""
if not profile_id and not universe_id:
return None, None
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user_id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user_id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
async def generate_and_save_story(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> Story:
"""Generate generic story content and save to DB."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
# 2. Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as exc:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
# 4. Save
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
# 5. Trigger Async Tasks
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_full_story_service(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
# 1. Generate text part
# We can reuse logic or call generate_story_content directly if we want finer control
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
story = await generate_and_save_story(request, user_id, db)
# 2. Generate Image (Parallel/Async step in this flow)
image_url: str | None = None
errors: dict[str, str | None] = {}
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False,
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
async def generate_storybook_service(
request: StorybookRequest,
user_id: str,
db: AsyncSession,
) -> StorybookResponse:
"""Generate storybook with parallel image generation for pages."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
logger.info(
"storybook_request",
user_id=user_id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
# 2. Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate Text Structure
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# 4. Parallel Image Generation
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
tasks = []
# Cover Task
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# Page Tasks
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# Execute
results = await asyncio.gather(*tasks, return_exceptions=True)
# Update cover result
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# 5. Save to DB
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data,
story_text=None,
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 6. Build Response
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
# ==================== Missing Endpoints Logic (for Issue #5) ====================
async def list_stories(
user_id: str,
limit: int,
offset: int,
db: AsyncSession,
) -> list[Story]:
"""List stories for user."""
result = await db.execute(
select(Story)
.where(Story.user_id == user_id)
.order_by(desc(Story.created_at))
.offset(offset)
.limit(limit)
)
return result.scalars().all()
async def get_story_detail(
story_id: int,
user_id: str,
db: AsyncSession,
) -> Story:
"""Get story detail."""
result = await db.execute(
select(Story).where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
async def delete_story(
story_id: int,
user_id: str,
db: AsyncSession,
) -> None:
"""Delete a story."""
story = await get_story_detail(story_id, user_id, db)
await db.delete(story)
await db.commit()
async def create_story_from_result(
result, # StoryOutput
user_id: str,
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_story_cover(
story_id: int,
user_id: str,
db: AsyncSession,
) -> str:
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
return image_url
except Exception as e:
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
async def generate_story_audio(
story_id: int,
user_id: str,
db: AsyncSession,
) -> bytes:
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
if not story.story_text:
raise HTTPException(status_code=400, detail="Story has no text")
# TODO: Check if audio is already cached/saved?
# For now, generate on the fly via provider
from app.services.provider_router import text_to_speech
try:
audio_data = await text_to_speech(story.story_text, db=db)
return audio_data
except Exception as e:
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
async def get_story_achievements(
story_id: int,
user_id: str,
db: AsyncSession,
) -> list[AchievementItem]:
"""Get achievements unlocked by a specific story."""
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user_id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results

View File

@@ -1,93 +0,0 @@
# 梦语织机 (DreamWeaver) 记忆系统升级 PRD
> 版本: v1.0 | 状态: 规划中 | 优先级: High
## 1. 核心愿景 (Vision)
将当前的"数据存储"升级为有温度的**"情感连接系统"**。
我们不只是在记住数据,而是在**维护孩子与故事世界的关系**。让每一个故事不再是孤立的碎片,而是构建孩子专属"故事宇宙"的砖瓦。
---
## 2. 产品痛点与解决方案
| 用户角色 | 核心痛点 | 解决方案 | 预期价值 |
|---------|---------|---------|---------|
| **孩子** | "上次的小兔子怎么不认识我了?" <br> 故事之间缺乏连续性,只有单次体验。 | **角色一致性与记忆注入** <br> 故事开头主动提及往事,角色性格延续。 | 建立情感依恋,提升沉浸感。 |
| **家长** | "这App除了生成故事还能干嘛" <br> 无法感知产品的长期教育价值。 | **显性化成长轨迹** <br> 词汇量统计、主题变化、成就徽章可视化。 | 提高付费意愿,提供社交货币。 |
| **平台** | 用户用完即走,缺乏留存壁垒。 | **沉没成本与情感资产** <br> 积累的记忆越多,越舍不得离开。 | 提升长期留存率 (LTV)。 |
---
## 3. 功能架构:记忆分层模型
### 3.1 层级 1: 核心档案 (Identity Layer)
*性质:永久、静态、显性*
* **数据**: 姓名、年龄、性别。
* **输入**: 家长在 Onboarding 阶段手动输入。
* **作用**: 决定故事的基础适龄性和称呼。
### 3.2 层级 2: 故事宇宙 (Universe Layer)
*性质:长期、动态积累、半显性*
* **主角设定**: 姓名、性格特征(勇敢/害羞)、外貌特征(戴眼镜/卷发)。
* **常驻配角**: 从随机故事中涌现出的固定伙伴(如"爱吃胡萝卜的松鼠奇奇")。
* **世界观**: 故事发生的背景(魔法森林、未来城市、海底世界)。
* **成就系统**: 孩子获得的虚拟奖励(勇气勋章、小小探险家)。
### 3.3 层级 3: 工作记忆 (Working Memory)
*性质:短期、自动衰减、隐性*
* **关键情节**: 最近 3 个故事的结局和核心冲突。
* **情感标记**: 孩子对特定内容的反应(根据“重播”、“跳过”推断)。
* **新学词汇**: 故事中出现的高级词汇。
---
## 4. 关键功能特性 (Feature Specs)
### 4.1 智能开场白 (Memory Injection)
在生成新故事时Prompt 必须包含一段"记忆唤醒"指令。
* **示例**: "小明,还记得上周我们帮小松鼠找回了松果吗?今天,小松鼠带来了一位新朋友..."
* **策略**: 提取权重最高的 Top 3 记忆注入 Prompt。
### 4.2 成长时间轴 (Growth Timeline)
一个可视化的 H5 页面或 App 模块,以时间轴形式展示里程碑。
* **节点类型**:
* 🌟 **初次相遇**: 创建角色的第一天。
* 📖 **阅读打卡**: 累计阅读 10/50/100 本。
* 🏅 **获得成就**: 获得"诚实勋章"。
* 🧠 **能力解锁**: 第一次阅读"科幻"题材。
### 4.3 成就仪式感 (Achievement Ceremony)
* **触发**: 故事生成并分析后,如果获得新成就。
* **表现**: 弹窗动画 + 音效 + "恭喜获得 [勇气] 徽章"。
* **分享**:允许生成带二维码的成就海报。
---
## 5. 记忆类型扩展 (Memory Types)
| 类型 Key | 描述 | 来源 | 过期策略 |
|---------|------|------|---------|
| `recent_story` | 最近读过的故事梗概 | 阅读事件 | 30天衰减 |
| `favorite_character` | 孩子喜欢的角色 | 重播/高评分 | 长期有效 |
| `scary_element` | 孩子害怕/不喜欢的元素 | 跳过/负反馈 | 长期有效 (避雷) |
| `vocabulary_growth` | 新掌握的词汇 | 故事分析 | 90天衰减 |
| `emotional_highlight` | 高光时刻 (如: 特别开心的情节) | 互动数据 | 60天衰减 |
---
## 6. 实施路线图 (Roadmap)
### Phase 1: 基础建设 (v0.3.0)
* [x] 数据库 `MemoryItem` 表 (已存在)。
* [ ] 扩展 `MemoryItem` 类型字段,支持更多维度。
* [ ] 优化 `_build_memory_context`,支持更自然的 Prompt 注入。
* [ ] 前端:简单的"近期回忆"展示列表。
### Phase 2: 可视化与成就 (v0.4.0)
* [ ] 实现"成就提取器" (Achievement Extractor) 的闭环通知。
* [ ] 前端:开发"我的成就"和"成长时间轴"页面。
* [ ] 增加故事开场白的动态生成逻辑。
### Phase 3: 深度智能 (v0.5.0+)
* [ ] 引入向量数据库,实现基于语义的记忆检索 (不仅是时间最近)。
* [ ] 情感分析模型:分析用户行为推断情感倾向。

View File

@@ -0,0 +1,98 @@
# DreamWeaver 重构实施计划
## 1. 概述
本文档基于对当前架构的深入分析,制定了从稳定性、可维护性到可扩展性的分阶段重构计划。
**目标**
- **短期**:解决单点故障风险,优化开发体验,清理关键技术债。
- **中期**:提升系统高可用能力,增强监控与可观测性。
- **长期**:架构演进,支持大规模并发与复杂业务场景。
---
## 2. 短期优化计划 (1-2周)
**重点**:消除即时风险,提升部署效率。
### 2.1 统一镜像构建 (High Priority)
目前 `backend`, `backend-admin`, `worker`, `celery-beat` 重复构建 4 次,浪费资源且镜像版本可能不一致。
- **Action Items**:
- [ ] 修改 `backend/Dockerfile` 为通用基础镜像。
- [ ] 更新 `docker-compose.yml`,定义 `backend-base` 服务或使用 `image` 标签共享镜像。
- [ ] 确保所有 Python 服务共用同一构建产物,仅启动命令不同。
### 2.2 修复 Provider 缓存与限流 (High Priority)
内存缓存 (`TTLCache`, `_latency_cache`) 在多进程/多实例下失效。
- **Action Items**:
- [ ] 引入 Redis 作为共享缓存后端。
- [ ] 重构 `_load_provider_cache`,将 Provider 配置缓存至 Redis。
- [ ] 重构 `stories.py` 中的限流逻辑,使用 `redis-cell` 或简单的 Redis 计数器替代 `TTLCache`
### 2.3 拆分 `stories.py` (Medium Priority)
`app/api/stories.py` 超过 600 行,包含 API 定义、业务逻辑、验证逻辑,维护困难。
- **Action Items**:
- [ ] 创建 `app/services/story_service.py`迁移生成、润色、PDF生成等核心逻辑。
- [ ] 创建 `app/schemas/story_schema.py`,迁移 Pydantic 模型(`GenerateRequest`, `StoryResponse` 等)。
- [ ] API 层 `stories.py` 仅保留路由定义和依赖注入,调用 Service 层。
---
## 3. 中期优化计划 (1-2月)
**重点**:高可用 (HA) 与系统韧性。
### 3.1 数据库高可用 (Critical)
当前 PostgreSQL 为单点,且 Admin/User 混合使用。
- **Action Items**:
- [ ] 部署 PostgreSQL 主从复制 (Master-Slave)。
- [ ] 配置 `PgBouncer` 或 SQLAlchemy 读写分离,减轻主库压力。
- [ ] 实施数据库自动备份策略 (如 `pg_dump` 定时上传 S3)。
### 3.2 消息队列高可用 (Critical)
Redis 单点故障将导致 Celery 任务全盘停摆。
- **Action Items**:
- [ ] 迁移至 Redis Sentinel 或 Redis Cluster 模式。
- [ ] 更新 Celery 配置以支持 Sentinel/Cluster 连接串。
### 3.3 增强可观测性 (Important)
目前仅有简单的日志,缺乏系统级指标。
- **Action Items**:
- [ ] 集成 Prometheus Client暴露 `/metrics` 端点。
- [ ] 部署 Grafana + Prometheus监控 API 延迟、QPS、Celery 队列积压情况。
- [ ] 完善 `ProviderMetrics`,增加可视化大盘,实时监控 AI 供应商的成本与成功率。
---
## 4. 长期架构演进 (季度规划)
**重点**:业务解耦与规模化。
### 4.1 统一 API 网关
- **当前**前端直连后端端口CORS 配置分散。
- **演进**:引入 Traefik 或 Nginx 作为统一网关管理路由、SSL、全局限流、统一鉴权。
### 4.2 前端工程合并
- **当前**User App 和 Admin Console 是完全独立的两个项目,但在组件和工具链上高度重复。
- **演进**:使用一种 Monorepo 策略或基于路由的单一应用策略,共享组件库和类型定义,减少维护成本。
### 4.3 事件驱动架构完善
- **当前**:部分业务逻辑耦合在 API 中。
- **演进**:扩展事件总线,将“阅读记录”、“成就解锁”、“通知推送”等非核心链路完全异步化,通过 Domain Events 解耦。
---
## 5. 实施路线图
| 阶段 | 时间估算 | 关键里程碑 |
| :--- | :--- | :--- |
| **Phase 1: 基础夯实** | Week 1-2 | Docker 构建优化上线Redis 替代内存缓存。 |
| **Phase 2: 代码重构** | Week 3-4 | `stories.py` 拆分完成Service 层建立。 |
| **Phase 3: 高可用建设** | Month 2 | 数据库与 Redis 实现主备/集群模式。 |
| **Phase 4: 监控体系** | Month 2 | Grafana 监控大盘上线,关键指标报警配置完毕。 |

View File

@@ -121,7 +121,7 @@ def mock_text_provider():
cover_prompt_suggestion="A cute rabbit", cover_prompt_suggestion="A cute rabbit",
) )
with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock: with patch("app.services.story_service.generate_story_content", new_callable=AsyncMock) as mock:
mock.return_value = mock_result mock.return_value = mock_result
yield mock yield mock
@@ -129,7 +129,7 @@ def mock_text_provider():
@pytest.fixture @pytest.fixture
def mock_image_provider(): def mock_image_provider():
"""Mock 图像生成。""" """Mock 图像生成。"""
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock: with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
mock.return_value = "https://example.com/image.png" mock.return_value = "https://example.com/image.png"
yield mock yield mock
@@ -137,7 +137,7 @@ def mock_image_provider():
@pytest.fixture @pytest.fixture
def mock_tts_provider(): def mock_tts_provider():
"""Mock TTS。""" """Mock TTS。"""
with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock: with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
mock.return_value = b"fake-audio-bytes" mock.return_value = b"fake-audio-bytes"
yield mock yield mock

View File

@@ -57,7 +57,6 @@ class TestStoryGenerate:
assert data["mode"] == "generated" assert data["mode"] == "generated"
@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
class TestStoryList: class TestStoryList:
"""故事列表测试。""" """故事列表测试。"""
@@ -94,7 +93,6 @@ class TestStoryList:
assert len(data) == 0 assert len(data) == 0
@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
class TestStoryDetail: class TestStoryDetail:
"""故事详情测试。""" """故事详情测试。"""
@@ -118,7 +116,6 @@ class TestStoryDetail:
assert data["story_text"] == test_story.story_text assert data["story_text"] == test_story.story_text
@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
class TestStoryDelete: class TestStoryDelete:
"""故事删除测试。""" """故事删除测试。"""
@@ -168,7 +165,6 @@ class TestRateLimit:
assert "Too many requests" in response.json()["detail"] assert "Too many requests" in response.json()["detail"]
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerate: class TestImageGenerate:
"""封面图片生成测试。""" """封面图片生成测试。"""
@@ -183,7 +179,6 @@ class TestImageGenerate:
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
class TestAudio: class TestAudio:
"""语音朗读测试。""" """语音朗读测试。"""
@@ -233,7 +228,7 @@ class TestGenerateFull:
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider): def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
"""图片生成失败时返回部分成功。""" """图片生成失败时返回部分成功。"""
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img: with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error") mock_img.side_effect = Exception("Image API error")
response = auth_client.post( response = auth_client.post(
"/api/stories/generate/full", "/api/stories/generate/full",
@@ -261,7 +256,6 @@ class TestGenerateFull:
assert call_kwargs["education_theme"] == "勇气与友谊" assert call_kwargs["education_theme"] == "勇气与友谊"
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerateSuccess: class TestImageGenerateSuccess:
"""封面图片生成成功测试。""" """封面图片生成成功测试。"""