diff --git a/backend/Dockerfile b/backend/Dockerfile
index 23de17c..3abb42c 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -2,26 +2,32 @@ FROM python:3.11-slim
WORKDIR /app
-# 安装系统依赖 (如果需要)
-# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/*
+# 设置环境变量
+ENV PYTHONDONTWRITEBYTECODE=1 \
+ PYTHONUNBUFFERED=1
-# 复制项目文件
+# 安装系统工具 (curl用于可能的健康检查)
+RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
+
+# 1. 缓存层:仅复制依赖定义并安装
+# 创建伪造的 app 目录以满足 pip install . 的要求
COPY pyproject.toml .
-# 复制源码
+RUN mkdir app && touch app/__init__.py
+RUN pip install --no-cache-dir .
+
+# 2. 源码层:复制真实代码
COPY app ./app
COPY alembic ./alembic
COPY alembic.ini .
-# 安装依赖
-# 使用 pip 安装当前目录 (.),会自动解析 pyproject.toml
-RUN pip install --no-cache-dir .
+# 再次安装本身(不带依赖),确保源码更新被标记为已安装
+RUN pip install --no-cache-dir --no-deps .
-# 创建静态文件目录 (用于存放生成的图片)
+# 创建静态文件目录
RUN mkdir -p static/images
# 暴露端口
EXPOSE 8000
-# 启动命令
-# 生产环境建议使用 gunicorn 或 uvicorn --workers
+# 默认启动命令(可被 docker-compose 覆盖)
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/backend/app/api/stories.py b/backend/app/api/stories.py
index ec34b65..cd84d74 100644
--- a/backend/app/api/stories.py
+++ b/backend/app/api/stories.py
@@ -2,136 +2,41 @@
import asyncio
import json
-import time
import uuid
-from typing import AsyncGenerator, Literal
+from typing import AsyncGenerator
-from fastapi import APIRouter, Depends, HTTPException, Request
-from fastapi.responses import Response
-from pydantic import BaseModel, Field
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette.sse import EventSourceResponse
+from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.core.logging import get_logger
-from app.db.database import get_db
-from app.db.models import ChildProfile, Story, StoryUniverse, User
-from app.services.provider_router import (
- generate_image,
- generate_story_content,
- generate_storybook,
- text_to_speech,
-)
from app.core.rate_limiter import check_rate_limit
-from app.tasks.achievements import extract_story_achievements
+from app.db.database import get_db
+from app.db.models import User
+from app.schemas.story_schemas import (
+ GenerateRequest,
+ StoryResponse,
+ FullStoryResponse,
+ StorybookRequest,
+ StorybookResponse,
+ StoryListItem,
+ AchievementItem,
+)
+from app.services import story_service
+from app.services.memory_service import build_enhanced_memory_context
+from app.services.provider_router import (
+ generate_story_content,
+ generate_image,
+)
logger = get_logger(__name__)
-
router = APIRouter()
-MAX_DATA_LENGTH = 2000
-MAX_EDU_THEME_LENGTH = 200
-MAX_TTS_LENGTH = 4000
-
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
-class GenerateRequest(BaseModel):
- """Story generation request."""
-
- type: Literal["keywords", "full_story"]
- data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
- education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
- child_profile_id: str | None = None
- universe_id: str | None = None
-
-
-class StoryResponse(BaseModel):
- """Story response."""
-
- id: int
- title: str
- story_text: str
- cover_prompt: str | None
- image_url: str | None
- mode: str
- child_profile_id: str | None = None
- universe_id: str | None = None
-
-
-class StoryListItem(BaseModel):
- """Story list item."""
-
- id: int
- title: str
- image_url: str | None
- created_at: str
- mode: str
-
-
-class FullStoryResponse(BaseModel):
- """完整故事响应(含图片和音频状态)。"""
-
- id: int
- title: str
- story_text: str
- cover_prompt: str | None
- image_url: str | None
- audio_ready: bool
- mode: str
- errors: dict[str, str | None] = Field(default_factory=dict)
- child_profile_id: str | None = None
- universe_id: str | None = None
-
-
-from app.services.memory_service import build_enhanced_memory_context
-
-
-async def _validate_profile_and_universe(
- request: GenerateRequest,
- user: User,
- db: AsyncSession,
-) -> tuple[str | None, str | None]:
- if not request.child_profile_id and not request.universe_id:
- return None, None
-
- profile_id = request.child_profile_id
- universe_id = request.universe_id
-
- if profile_id:
- result = await db.execute(
- select(ChildProfile).where(
- ChildProfile.id == profile_id,
- ChildProfile.user_id == user.id,
- )
- )
- profile = result.scalar_one_or_none()
- if not profile:
- raise HTTPException(status_code=404, detail="孩子档案不存在")
-
- if universe_id:
- result = await db.execute(
- select(StoryUniverse)
- .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
- .where(
- StoryUniverse.id == universe_id,
- ChildProfile.user_id == user.id,
- )
- )
- universe = result.scalar_one_or_none()
- if not universe:
- raise HTTPException(status_code=404, detail="故事宇宙不存在")
- if profile_id and universe.child_profile_id != profile_id:
- raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
- if not profile_id:
- profile_id = universe.child_profile_id
-
- return profile_id, universe_id
-
-
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
@@ -140,47 +45,7 @@ async def generate_story(
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
- profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
- memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
-
- try:
- result = await generate_story_content(
- input_type=request.type,
- data=request.data,
- education_theme=request.education_theme,
- memory_context=memory_context,
- )
- except HTTPException:
- raise
- except Exception:
- raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
-
- story = Story(
- user_id=user.id,
- child_profile_id=profile_id,
- universe_id=universe_id,
- title=result.title,
- story_text=result.story_text,
- cover_prompt=result.cover_prompt_suggestion,
- mode=result.mode,
- )
- db.add(story)
- await db.commit()
- await db.refresh(story)
-
- if universe_id:
- extract_story_achievements.delay(story.id, universe_id)
-
- return StoryResponse(
- id=story.id,
- title=story.title,
- story_text=story.story_text,
- cover_prompt=story.cover_prompt,
- image_url=story.image_url,
- mode=story.mode,
- child_profile_id=story.child_profile_id,
- universe_id=story.universe_id,
- )
+ return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
@@ -189,71 +54,9 @@ async def generate_story_full(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
- """生成完整故事(故事 + 并行生成图片和音频)。
-
- 部分成功策略:故事必须成功,图片/音频失败不影响整体。
- """
+ """Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
- profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
- memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
-
- # Step 1: 故事生成(必须成功)
- try:
- result = await generate_story_content(
- input_type=request.type,
- data=request.data,
- education_theme=request.education_theme,
- memory_context=memory_context,
- )
- except Exception as exc:
- logger.error("story_generation_failed", error=str(exc))
- raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
-
- # 保存故事
- story = Story(
- user_id=user.id,
- child_profile_id=profile_id,
- universe_id=universe_id,
- title=result.title,
- story_text=result.story_text,
- cover_prompt=result.cover_prompt_suggestion,
- mode=result.mode,
- )
- db.add(story)
- await db.commit()
- await db.refresh(story)
-
- if universe_id:
- extract_story_achievements.delay(story.id, universe_id)
-
- # Step 2: 生成封面图片(音频按需生成,避免浪费)
- errors: dict[str, str | None] = {}
- image_url: str | None = None
-
- if story.cover_prompt:
- try:
- image_url = await generate_image(story.cover_prompt)
- story.image_url = image_url
- await db.commit()
- except Exception as exc:
- errors["image"] = str(exc)
- logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
-
- # 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
- # 这样避免生成后丢弃造成的成本浪费
-
- return FullStoryResponse(
- id=story.id,
- title=story.title,
- story_text=story.story_text,
- cover_prompt=story.cover_prompt,
- image_url=image_url,
- audio_ready=False, # 音频需要用户主动请求
- mode=story.mode,
- errors=errors,
- child_profile_id=story.child_profile_id,
- universe_id=story.universe_id,
- )
+ return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
@@ -263,53 +66,39 @@ async def generate_story_stream(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
- """流式生成故事(SSE)。
-
- 事件流程:
- - started: 返回 story_id
- - story_ready: 返回 title, content
- - story_failed: 返回 error
- - image_ready: 返回 image_url
- - image_failed: 返回 error
- - complete: 结束流
- """
+ """流式生成故事(SSE)。"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
- profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
+
+ # Validation
+ profile_id, universe_id = await story_service.validate_profile_and_universe(
+ request.child_profile_id, request.universe_id, user.id, db
+ )
+
+ # Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
- # Step 1: 生成故事
+ # Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
+ db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
- # 保存故事
- story = Story(
- user_id=user.id,
- child_profile_id=profile_id,
- universe_id=universe_id,
- title=result.title,
- story_text=result.story_text,
- cover_prompt=result.cover_prompt_suggestion,
- mode=result.mode,
+ # Save Story
+ story = await story_service.create_story_from_result(
+ result, user.id, profile_id, universe_id, db
)
- db.add(story)
- await db.commit()
- await db.refresh(story)
-
- if universe_id:
- extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
@@ -324,10 +113,11 @@ async def generate_story_stream(
}),
}
- # Step 2: 并行生成图片(音频按需)
+ # Step 2: Generate Image
if story.cover_prompt:
try:
- image_url = await generate_image(story.cover_prompt)
+ # Direct call to provider router's generate_image, sharing db session
+ image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
@@ -340,217 +130,71 @@ async def generate_story_stream(
return EventSourceResponse(event_generator())
-# ==================== Storybook API ====================
-
-
-class StorybookRequest(BaseModel):
- """Storybook 生成请求。"""
-
- keywords: str = Field(..., min_length=1, max_length=200)
- page_count: int = Field(default=6, ge=4, le=12)
- education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
- generate_images: bool = Field(default=False, description="是否同时生成插图")
- child_profile_id: str | None = None
- universe_id: str | None = None
-
-
-class StorybookPageResponse(BaseModel):
- """故事书单页响应。"""
-
- page_number: int
- text: str
- image_prompt: str
- image_url: str | None = None
-
-
-class StorybookResponse(BaseModel):
- """故事书响应。"""
-
- id: int | None = None
- title: str
- main_character: str
- art_style: str
- pages: list[StorybookPageResponse]
- cover_prompt: str
- cover_url: str | None = None
-
-
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
- """生成分页故事书并保存。
-
- 返回故事书结构,包含每页文字和图像提示词。
- """
+ """Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
-
- # 验证档案和宇宙
- # 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
- # 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
-
- # 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
- profile_id = request.child_profile_id
- universe_id = request.universe_id
-
- if profile_id:
- result = await db.execute(
- select(ChildProfile).where(
- ChildProfile.id == profile_id,
- ChildProfile.user_id == user.id,
- )
- )
- if not result.scalar_one_or_none():
- raise HTTPException(status_code=404, detail="孩子档案不存在")
-
- if universe_id:
- result = await db.execute(
- select(StoryUniverse)
- .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
- .where(
- StoryUniverse.id == universe_id,
- ChildProfile.user_id == user.id,
- )
- )
- universe = result.scalar_one_or_none()
- if not universe:
- raise HTTPException(status_code=404, detail="故事宇宙不存在")
- if profile_id and universe.child_profile_id != profile_id:
- raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
- if not profile_id:
- profile_id = universe.child_profile_id
-
- logger.info(
- "storybook_request",
- user_id=user.id,
- keywords=request.keywords,
- page_count=request.page_count,
- profile_id=profile_id,
- universe_id=universe_id,
- )
-
- memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
-
- try:
- # 注意:generate_storybook 目前可能不支持记忆上下文注入
- # 我们需要看看 generate_storybook 的签名
- # 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
- storybook = await generate_storybook(
- keywords=request.keywords,
- page_count=request.page_count,
- education_theme=request.education_theme,
- memory_context=memory_context,
- db=db,
- )
- except Exception as e:
- logger.error("storybook_generation_failed", error=str(e))
- raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
-
- # ==============================================================================
- # 核心升级: 并行全量生成 (Parallel Full Rendering)
- # ==============================================================================
- final_cover_url = storybook.cover_url
-
- if request.generate_images:
- logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
-
- # 1. 准备所有生图任务 (封面 + 所有内页)
- tasks = []
-
- # 封面任务
- async def _gen_cover():
- if storybook.cover_prompt and not storybook.cover_url:
- try:
- return await generate_image(storybook.cover_prompt, db=db)
- except Exception as e:
- logger.warning("cover_gen_failed", error=str(e))
- return storybook.cover_url
-
- tasks.append(_gen_cover())
-
- # 内页任务
- async def _gen_page(page):
- if page.image_prompt and not page.image_url:
- try:
- url = await generate_image(page.image_prompt, db=db)
- page.image_url = url
- except Exception as e:
- logger.warning("page_gen_failed", page=page.page_number, error=str(e))
-
- for page in storybook.pages:
- tasks.append(_gen_page(page))
-
- # 2. 并发执行所有任务
- # 使用 return_exceptions=True 防止单张失败影响整体
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- # 3. 更新封面结果 (results[0] 是封面任务的返回值)
- cover_res = results[0]
- if isinstance(cover_res, str):
- final_cover_url = cover_res
-
- logger.info("storybook_parallel_generation_complete")
-
- # ==============================================================================
-
- # 构建并保存 Story 对象
- # 将 pages 对象转换为字典列表以存入 JSON 字段
- pages_data = [
- {
- "page_number": p.page_number,
- "text": p.text,
- "image_prompt": p.image_prompt,
- "image_url": p.image_url,
- }
- for p in storybook.pages
- ]
-
- story = Story(
- user_id=user.id,
- child_profile_id=profile_id,
- universe_id=universe_id,
- title=storybook.title,
- mode="storybook",
- pages=pages_data, # 存入 JSON 字段
- story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
- cover_prompt=storybook.cover_prompt,
- image_url=final_cover_url,
- )
- db.add(story)
- await db.commit()
- await db.refresh(story)
-
- if universe_id:
- extract_story_achievements.delay(story.id, universe_id)
-
- # 构建响应 (使用更新后的 pages_data)
- response_pages = [
- StorybookPageResponse(
- page_number=p["page_number"],
- text=p["text"],
- image_prompt=p["image_prompt"],
- image_url=p.get("image_url"),
- )
- for p in pages_data
- ]
-
- return StorybookResponse(
- id=story.id,
- title=storybook.title,
- main_character=storybook.main_character,
- art_style=storybook.art_style,
- pages=response_pages,
- cover_prompt=storybook.cover_prompt,
- cover_url=final_cover_url,
- )
+ return await story_service.generate_storybook_service(request, user.id, db)
-class AchievementItem(BaseModel):
- type: str
- description: str
- obtained_at: str | None = None
+# ==================== Missing Endpoints (Issue #5) ====================
+
+@router.get("/stories", response_model=list[StoryListItem])
+async def list_stories(
+ limit: int = 20,
+ offset: int = 0,
+ user: User = Depends(require_user),
+ db: AsyncSession = Depends(get_db),
+):
+ """List stories."""
+ return await story_service.list_stories(user.id, limit, offset, db)
+
+
+@router.get("/stories/{story_id}", response_model=StoryResponse)
+async def get_story(
+ story_id: int,
+ user: User = Depends(require_user),
+ db: AsyncSession = Depends(get_db),
+):
+ """Get story detail."""
+ return await story_service.get_story_detail(story_id, user.id, db)
+
+
+@router.delete("/stories/{story_id}")
+async def delete_story(
+ story_id: int,
+ user: User = Depends(require_user),
+ db: AsyncSession = Depends(get_db),
+):
+ """Delete story."""
+ await story_service.delete_story(story_id, user.id, db)
+ return {"message": "Deleted"}
+
+
+@router.post("/image/generate/{story_id}")
+async def generate_story_image(
+ story_id: int,
+ user: User = Depends(require_user),
+ db: AsyncSession = Depends(get_db),
+):
+ """Generate cover image for story."""
+ url = await story_service.generate_story_cover(story_id, user.id, db)
+ return {"image_url": url}
+
+
+@router.get("/audio/{story_id}")
+async def get_story_audio(
+ story_id: int,
+ user: User = Depends(require_user),
+ db: AsyncSession = Depends(get_db),
+):
+ """Get story audio (MP3)."""
+ audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
+ return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
@@ -559,32 +203,5 @@ async def get_story_achievements(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
- """Get achievements unlocked by a specific story."""
- # 使用 joinedload 避免 N+1 查询
- result = await db.execute(
- select(Story)
- .options(joinedload(Story.story_universe))
- .where(Story.id == story_id, Story.user_id == user.id)
- )
- story = result.scalar_one_or_none()
-
- if not story:
- raise HTTPException(status_code=404, detail="Story not found")
-
- if not story.universe_id or not story.story_universe:
- return []
-
- universe = story.story_universe
- if not universe.achievements:
- return []
-
- results = []
- for ach in universe.achievements:
- if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
- results.append(AchievementItem(
- type=ach.get("type", "Unknown"),
- description=ach.get("description", ""),
- obtained_at=ach.get("obtained_at")
- ))
-
- return results
+ """Get story achievements."""
+ return await story_service.get_story_achievements(story_id, user.id, db)
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index ca038b9..8823b15 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -52,6 +52,9 @@ class Settings(BaseSettings):
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")
+
+ # Generic Redis
+ redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
# Admin console
enable_admin_console: bool = False
diff --git a/backend/app/core/redis.py b/backend/app/core/redis.py
new file mode 100644
index 0000000..f8e6962
--- /dev/null
+++ b/backend/app/core/redis.py
@@ -0,0 +1,25 @@
+"""Redis client module."""
+
+from typing import AsyncGenerator
+
+from redis.asyncio import Redis, from_url
+
+from app.core.config import settings
+
+_redis_pool: Redis | None = None
+
+
+async def get_redis() -> Redis:
+ """Get global Redis client instance."""
+ global _redis_pool
+ if _redis_pool is None:
+ _redis_pool = from_url(settings.redis_url, encoding="utf-8", decode_responses=True)
+ return _redis_pool
+
+
+async def close_redis():
+ """Close Redis connection."""
+ global _redis_pool
+ if _redis_pool:
+ await _redis_pool.close()
+ _redis_pool = None
diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py
new file mode 100644
index 0000000..d2beb0e
--- /dev/null
+++ b/backend/app/schemas/__init__.py
@@ -0,0 +1 @@
+"""故事相关 Schema 模块。"""
diff --git a/backend/app/schemas/story_schemas.py b/backend/app/schemas/story_schemas.py
new file mode 100644
index 0000000..80a8c80
--- /dev/null
+++ b/backend/app/schemas/story_schemas.py
@@ -0,0 +1,106 @@
+"""故事相关 Pydantic 模型。"""
+
+from datetime import datetime
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+MAX_DATA_LENGTH = 2000
+MAX_EDU_THEME_LENGTH = 200
+MAX_TTS_LENGTH = 4000
+
+
+# ==================== 故事模型 ====================
+
+
+class GenerateRequest(BaseModel):
+ """Story generation request."""
+
+ type: Literal["keywords", "full_story"]
+ data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
+ education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
+ child_profile_id: str | None = None
+ universe_id: str | None = None
+
+
+class StoryResponse(BaseModel):
+ """Story response."""
+
+ id: int
+ title: str
+ story_text: str
+ cover_prompt: str | None
+ image_url: str | None
+ mode: str
+ child_profile_id: str | None = None
+ universe_id: str | None = None
+
+
+class StoryListItem(BaseModel):
+ """Story list item."""
+
+ id: int
+ title: str
+ image_url: str | None
+ created_at: datetime
+ mode: str
+
+
+class FullStoryResponse(BaseModel):
+ """完整故事响应(含图片和音频状态)。"""
+
+ id: int
+ title: str
+ story_text: str
+ cover_prompt: str | None
+ image_url: str | None
+ audio_ready: bool
+ mode: str
+ errors: dict[str, str | None] = Field(default_factory=dict)
+ child_profile_id: str | None = None
+ universe_id: str | None = None
+
+
+# ==================== 绘本模型 ====================
+
+
+class StorybookRequest(BaseModel):
+ """Storybook 生成请求。"""
+
+ keywords: str = Field(..., min_length=1, max_length=200)
+ page_count: int = Field(default=6, ge=4, le=12)
+ education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
+ generate_images: bool = Field(default=False, description="是否同时生成插图")
+ child_profile_id: str | None = None
+ universe_id: str | None = None
+
+
+class StorybookPageResponse(BaseModel):
+ """故事书单页响应。"""
+
+ page_number: int
+ text: str
+ image_prompt: str
+ image_url: str | None = None
+
+
+class StorybookResponse(BaseModel):
+ """故事书响应。"""
+
+ id: int | None = None
+ title: str
+ main_character: str
+ art_style: str
+ pages: list[StorybookPageResponse]
+ cover_prompt: str
+ cover_url: str | None = None
+
+
+# ==================== 成就模型 ====================
+
+
+class AchievementItem(BaseModel):
+ type: str
+ description: str
+ obtained_at: str | None = None
diff --git a/backend/app/services/provider_cache.py b/backend/app/services/provider_cache.py
index bd59404..e1a0db3 100644
--- a/backend/app/services/provider_cache.py
+++ b/backend/app/services/provider_cache.py
@@ -1,31 +1,109 @@
-"""In-memory cache for providers loaded from DB."""
+"""Redis-backed cache for providers loaded from DB."""
+import json
from collections import defaultdict
from typing import Literal
+from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from app.core.logging import get_logger
+from app.core.redis import get_redis
from app.db.admin_models import Provider
+logger = get_logger(__name__)
+
ProviderType = Literal["text", "image", "tts", "storybook"]
-_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
+
+class CachedProvider(BaseModel):
+ """Serializable provider configuration matching DB model fields."""
+ id: str
+ name: str
+ type: str
+ adapter: str
+ model: str | None = None
+ api_base: str | None = None
+ api_key: str | None = None
+ timeout_ms: int = 60000
+ max_retries: int = 1
+ weight: int = 1
+ priority: int = 0
+ enabled: bool = True
+ config_json: dict | None = None
+ config_ref: str | None = None
-async def reload_providers(db: AsyncSession):
- result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
- providers = result.scalars().all()
- grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
- for p in providers:
- grouped[p.type].append(p)
- # sort by priority desc, then weight desc
- for k in grouped:
- grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
- _cache.clear()
- _cache.update(grouped)
- return _cache
+# Local memory fallback (L1 cache)
+_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
+CACHE_KEY = "dreamweaver:providers:config"
-def get_providers(provider_type: ProviderType) -> list[Provider]:
- return _cache.get(provider_type, [])
+async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
+ """Reload providers from DB and update Redis cache."""
+ try:
+ result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
+ providers = result.scalars().all()
+
+ # Convert to Pydantic models
+ cached_list = []
+ for p in providers:
+ cached_list.append(CachedProvider(
+ id=p.id,
+ name=p.name,
+ type=p.type,
+ adapter=p.adapter,
+ model=p.model,
+ api_base=p.api_base,
+ api_key=p.api_key,
+ timeout_ms=p.timeout_ms,
+ max_retries=p.max_retries,
+ weight=p.weight,
+ priority=p.priority,
+ enabled=p.enabled,
+ config_json=p.config_json,
+ config_ref=p.config_ref
+ ))
+
+ # Group by type
+ grouped: dict[str, list[CachedProvider]] = defaultdict(list)
+ for cp in cached_list:
+ grouped[cp.type].append(cp)
+
+ # Sort
+ for k in grouped:
+ grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
+
+ # Update Redis
+ redis = await get_redis()
+ # Serialize entire dict structure
+ # Pydantic -> dict -> json
+ json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
+ await redis.set(CACHE_KEY, json.dumps(json_data))
+
+ # Update local cache
+ _local_cache.clear()
+ _local_cache.update(grouped)
+ return grouped
+
+ except Exception as e:
+ logger.error("failed_to_reload_providers", error=str(e))
+ raise
+
+
+async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
+ """Get providers from Redis (preferred) or local fallback."""
+ try:
+ redis = await get_redis()
+ data = await redis.get(CACHE_KEY)
+ if data:
+ raw_dict = json.loads(data)
+ if provider_type in raw_dict:
+ return [CachedProvider(**item) for item in raw_dict[provider_type]]
+ return []
+ except Exception as e:
+ logger.warning("redis_cache_read_failed", error=str(e))
+
+ # Fallback to local memory
+ return _local_cache.get(provider_type, [])
diff --git a/backend/app/services/provider_router.py b/backend/app/services/provider_router.py
index 259a152..d49947a 100644
--- a/backend/app/services/provider_router.py
+++ b/backend/app/services/provider_router.py
@@ -177,7 +177,7 @@ def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
)
-def _get_providers_with_config(
+async def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。
@@ -185,7 +185,7 @@ def _get_providers_with_config(
Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序
"""
- db_providers = get_providers(provider_type)
+ db_providers = await get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
@@ -265,7 +265,7 @@ async def _route_with_failover(
user_id: 用户 ID(可选,用于成本追踪和预算检查)
**kwargs: 传递给适配器的参数
"""
- providers = _get_providers_with_config(provider_type)
+ providers = await _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
diff --git a/backend/app/services/story_service.py b/backend/app/services/story_service.py
new file mode 100644
index 0000000..838275f
--- /dev/null
+++ b/backend/app/services/story_service.py
@@ -0,0 +1,434 @@
+"""Story business logic service."""
+
+import asyncio
+import json
+import uuid
+from typing import Literal
+
+from fastapi import HTTPException
+from sqlalchemy import select, desc
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
+
+from app.core.logging import get_logger
+from app.db.models import ChildProfile, Story, StoryUniverse
+from app.schemas.story_schemas import (
+ GenerateRequest,
+ StorybookRequest,
+ FullStoryResponse,
+ StorybookResponse,
+ StorybookPageResponse,
+ AchievementItem,
+)
+from app.services.memory_service import build_enhanced_memory_context
+from app.services.provider_router import (
+ generate_story_content,
+ generate_image,
+ generate_storybook,
+)
+from app.tasks.achievements import extract_story_achievements
+
+logger = get_logger(__name__)
+
+
+async def validate_profile_and_universe(
+ profile_id: str | None,
+ universe_id: str | None,
+ user_id: str,
+ db: AsyncSession,
+) -> tuple[str | None, str | None]:
+ """Validate child profile and universe ownership/relationship."""
+ if not profile_id and not universe_id:
+ return None, None
+
+ if profile_id:
+ result = await db.execute(
+ select(ChildProfile).where(
+ ChildProfile.id == profile_id,
+ ChildProfile.user_id == user_id,
+ )
+ )
+ profile = result.scalar_one_or_none()
+ if not profile:
+ raise HTTPException(status_code=404, detail="孩子档案不存在")
+
+ if universe_id:
+ result = await db.execute(
+ select(StoryUniverse)
+ .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
+ .where(
+ StoryUniverse.id == universe_id,
+ ChildProfile.user_id == user_id,
+ )
+ )
+ universe = result.scalar_one_or_none()
+ if not universe:
+ raise HTTPException(status_code=404, detail="故事宇宙不存在")
+ if profile_id and universe.child_profile_id != profile_id:
+ raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
+ if not profile_id:
+ profile_id = universe.child_profile_id
+
+ return profile_id, universe_id
+
+
+async def generate_and_save_story(
+ request: GenerateRequest,
+ user_id: str,
+ db: AsyncSession,
+) -> Story:
+ """Generate generic story content and save to DB."""
+ # 1. Validate
+ profile_id, universe_id = await validate_profile_and_universe(
+ request.child_profile_id, request.universe_id, user_id, db
+ )
+
+ # 2. Build Context
+ memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
+
+ # 3. Generate
+ try:
+ result = await generate_story_content(
+ input_type=request.type,
+ data=request.data,
+ education_theme=request.education_theme,
+ memory_context=memory_context,
+ db=db,
+ )
+ except Exception as exc:
+ raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
+
+ # 4. Save
+ story = Story(
+ user_id=user_id,
+ child_profile_id=profile_id,
+ universe_id=universe_id,
+ title=result.title,
+ story_text=result.story_text,
+ cover_prompt=result.cover_prompt_suggestion,
+ mode=result.mode,
+ )
+ db.add(story)
+ await db.commit()
+ await db.refresh(story)
+
+ # 5. Trigger Async Tasks
+ if universe_id:
+ extract_story_achievements.delay(story.id, universe_id)
+
+ return story
+
+
+async def generate_full_story_service(
+ request: GenerateRequest,
+ user_id: str,
+ db: AsyncSession,
+) -> FullStoryResponse:
+ """Generate story with parallel image generation."""
+ # 1. Generate text part
+ # We can reuse logic or call generate_story_content directly if we want finer control
+ # reusing generate_and_save_story to ensure consistency (it handles validation + saving)
+ story = await generate_and_save_story(request, user_id, db)
+
+ # 2. Generate Image (Parallel/Async step in this flow)
+ image_url: str | None = None
+ errors: dict[str, str | None] = {}
+
+ if story.cover_prompt:
+ try:
+ image_url = await generate_image(story.cover_prompt, db=db)
+ story.image_url = image_url
+ await db.commit()
+ except Exception as exc:
+ errors["image"] = str(exc)
+ logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
+
+ return FullStoryResponse(
+ id=story.id,
+ title=story.title,
+ story_text=story.story_text,
+ cover_prompt=story.cover_prompt,
+ image_url=image_url,
+ audio_ready=False,
+ mode=story.mode,
+ errors=errors,
+ child_profile_id=story.child_profile_id,
+ universe_id=story.universe_id,
+ )
+
+
+async def generate_storybook_service(
+ request: StorybookRequest,
+ user_id: str,
+ db: AsyncSession,
+) -> StorybookResponse:
+ """Generate storybook with parallel image generation for pages."""
+ # 1. Validate
+ profile_id, universe_id = await validate_profile_and_universe(
+ request.child_profile_id, request.universe_id, user_id, db
+ )
+
+ logger.info(
+ "storybook_request",
+ user_id=user_id,
+ keywords=request.keywords,
+ page_count=request.page_count,
+ profile_id=profile_id,
+ universe_id=universe_id,
+ )
+
+ # 2. Context
+ memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
+
+ # 3. Generate Text Structure
+ try:
+ storybook = await generate_storybook(
+ keywords=request.keywords,
+ page_count=request.page_count,
+ education_theme=request.education_theme,
+ memory_context=memory_context,
+ db=db,
+ )
+ except Exception as e:
+ logger.error("storybook_generation_failed", error=str(e))
+ raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
+
+ # 4. Parallel Image Generation
+ final_cover_url = storybook.cover_url
+ if request.generate_images:
+ logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
+
+ tasks = []
+
+ # Cover Task
+ async def _gen_cover():
+ if storybook.cover_prompt and not storybook.cover_url:
+ try:
+ return await generate_image(storybook.cover_prompt, db=db)
+ except Exception as e:
+ logger.warning("cover_gen_failed", error=str(e))
+ return storybook.cover_url
+ tasks.append(_gen_cover())
+
+ # Page Tasks
+ async def _gen_page(page):
+ if page.image_prompt and not page.image_url:
+ try:
+ url = await generate_image(page.image_prompt, db=db)
+ page.image_url = url
+ except Exception as e:
+ logger.warning("page_gen_failed", page=page.page_number, error=str(e))
+
+ for page in storybook.pages:
+ tasks.append(_gen_page(page))
+
+ # Execute
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ # Update cover result
+ cover_res = results[0]
+ if isinstance(cover_res, str):
+ final_cover_url = cover_res
+
+ logger.info("storybook_parallel_generation_complete")
+
+ # 5. Save to DB
+ pages_data = [
+ {
+ "page_number": p.page_number,
+ "text": p.text,
+ "image_prompt": p.image_prompt,
+ "image_url": p.image_url,
+ }
+ for p in storybook.pages
+ ]
+
+ story = Story(
+ user_id=user_id,
+ child_profile_id=profile_id,
+ universe_id=universe_id,
+ title=storybook.title,
+ mode="storybook",
+ pages=pages_data,
+ story_text=None,
+ cover_prompt=storybook.cover_prompt,
+ image_url=final_cover_url,
+ )
+ db.add(story)
+ await db.commit()
+ await db.refresh(story)
+
+ if universe_id:
+ extract_story_achievements.delay(story.id, universe_id)
+
+ # 6. Build Response
+ response_pages = [
+ StorybookPageResponse(
+ page_number=p["page_number"],
+ text=p["text"],
+ image_prompt=p["image_prompt"],
+ image_url=p.get("image_url"),
+ )
+ for p in pages_data
+ ]
+
+ return StorybookResponse(
+ id=story.id,
+ title=storybook.title,
+ main_character=storybook.main_character,
+ art_style=storybook.art_style,
+ pages=response_pages,
+ cover_prompt=storybook.cover_prompt,
+ cover_url=final_cover_url,
+ )
+
+
+# ==================== Missing Endpoints Logic (for Issue #5) ====================
+
+async def list_stories(
+ user_id: str,
+ limit: int,
+ offset: int,
+ db: AsyncSession,
+) -> list[Story]:
+ """List stories for user."""
+ result = await db.execute(
+ select(Story)
+ .where(Story.user_id == user_id)
+ .order_by(desc(Story.created_at))
+ .offset(offset)
+ .limit(limit)
+ )
+ return result.scalars().all()
+
+
+async def get_story_detail(
+ story_id: int,
+ user_id: str,
+ db: AsyncSession,
+) -> Story:
+ """Get story detail."""
+ result = await db.execute(
+ select(Story).where(Story.id == story_id, Story.user_id == user_id)
+ )
+ story = result.scalar_one_or_none()
+ if not story:
+ raise HTTPException(status_code=404, detail="Story not found")
+ return story
+
+
+async def delete_story(
+ story_id: int,
+ user_id: str,
+ db: AsyncSession,
+) -> None:
+ """Delete a story."""
+ story = await get_story_detail(story_id, user_id, db)
+ await db.delete(story)
+ await db.commit()
+
+
+async def create_story_from_result(
+ result, # StoryOutput
+ user_id: str,
+ profile_id: str | None,
+ universe_id: str | None,
+ db: AsyncSession,
+) -> Story:
+ """Save a generated story to DB (helper for stream endpoint)."""
+ story = Story(
+ user_id=user_id,
+ child_profile_id=profile_id,
+ universe_id=universe_id,
+ title=result.title,
+ story_text=result.story_text,
+ cover_prompt=result.cover_prompt_suggestion,
+ mode=result.mode,
+ )
+ db.add(story)
+ await db.commit()
+ await db.refresh(story)
+
+ if universe_id:
+ extract_story_achievements.delay(story.id, universe_id)
+
+ return story
+
+
+async def generate_story_cover(
+ story_id: int,
+ user_id: str,
+ db: AsyncSession,
+) -> str:
+ """Generate cover image for an existing story."""
+ story = await get_story_detail(story_id, user_id, db)
+
+ if not story.cover_prompt:
+ raise HTTPException(status_code=400, detail="Story has no cover prompt")
+
+ try:
+ image_url = await generate_image(story.cover_prompt, db=db)
+ story.image_url = image_url
+ await db.commit()
+ return image_url
+ except Exception as e:
+ logger.error("cover_generation_failed", story_id=story_id, error=str(e))
+ raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
+
+
+async def generate_story_audio(
+ story_id: int,
+ user_id: str,
+ db: AsyncSession,
+) -> bytes:
+ """Generate audio for a story."""
+ story = await get_story_detail(story_id, user_id, db)
+
+ if not story.story_text:
+ raise HTTPException(status_code=400, detail="Story has no text")
+
+ # TODO: Check if audio is already cached/saved?
+ # For now, generate on the fly via provider
+ from app.services.provider_router import text_to_speech
+
+ try:
+ audio_data = await text_to_speech(story.story_text, db=db)
+ return audio_data
+ except Exception as e:
+ logger.error("audio_generation_failed", story_id=story_id, error=str(e))
+ raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
+
+
+async def get_story_achievements(
+ story_id: int,
+ user_id: str,
+ db: AsyncSession,
+) -> list[AchievementItem]:
+ """Get achievements unlocked by a specific story."""
+ result = await db.execute(
+ select(Story)
+ .options(joinedload(Story.story_universe))
+ .where(Story.id == story_id, Story.user_id == user_id)
+ )
+ story = result.scalar_one_or_none()
+
+ if not story:
+ raise HTTPException(status_code=404, detail="Story not found")
+
+ if not story.universe_id or not story.story_universe:
+ return []
+
+ universe = story.story_universe
+ if not universe.achievements:
+ return []
+
+ results = []
+ for ach in universe.achievements:
+ if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
+ results.append(AchievementItem(
+ type=ach.get("type", "Unknown"),
+ description=ach.get("description", ""),
+ obtained_at=ach.get("obtained_at")
+ ))
+ return results
+
diff --git a/backend/docs/memory_system_prd.md b/backend/docs/memory_system_prd.md
deleted file mode 100644
index 4076ad8..0000000
--- a/backend/docs/memory_system_prd.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# 梦语织机 (DreamWeaver) 记忆系统升级 PRD
-> 版本: v1.0 | 状态: 规划中 | 优先级: High
-
-## 1. 核心愿景 (Vision)
-
-将当前的"数据存储"升级为有温度的**"情感连接系统"**。
-我们不只是在记住数据,而是在**维护孩子与故事世界的关系**。让每一个故事不再是孤立的碎片,而是构建孩子专属"故事宇宙"的砖瓦。
-
----
-
-## 2. 产品痛点与解决方案
-
-| 用户角色 | 核心痛点 | 解决方案 | 预期价值 |
-|---------|---------|---------|---------|
-| **孩子** | "上次的小兔子怎么不认识我了?"
故事之间缺乏连续性,只有单次体验。 | **角色一致性与记忆注入**
故事开头主动提及往事,角色性格延续。 | 建立情感依恋,提升沉浸感。 |
-| **家长** | "这App除了生成故事还能干嘛?"
无法感知产品的长期教育价值。 | **显性化成长轨迹**
词汇量统计、主题变化、成就徽章可视化。 | 提高付费意愿,提供社交货币。 |
-| **平台** | 用户用完即走,缺乏留存壁垒。 | **沉没成本与情感资产**
积累的记忆越多,越舍不得离开。 | 提升长期留存率 (LTV)。 |
-
----
-
-## 3. 功能架构:记忆分层模型
-
-### 3.1 层级 1: 核心档案 (Identity Layer)
-*性质:永久、静态、显性*
-* **数据**: 姓名、年龄、性别。
-* **输入**: 家长在 Onboarding 阶段手动输入。
-* **作用**: 决定故事的基础适龄性和称呼。
-
-### 3.2 层级 2: 故事宇宙 (Universe Layer)
-*性质:长期、动态积累、半显性*
-* **主角设定**: 姓名、性格特征(勇敢/害羞)、外貌特征(戴眼镜/卷发)。
-* **常驻配角**: 从随机故事中涌现出的固定伙伴(如"爱吃胡萝卜的松鼠奇奇")。
-* **世界观**: 故事发生的背景(魔法森林、未来城市、海底世界)。
-* **成就系统**: 孩子获得的虚拟奖励(勇气勋章、小小探险家)。
-
-### 3.3 层级 3: 工作记忆 (Working Memory)
-*性质:短期、自动衰减、隐性*
-* **关键情节**: 最近 3 个故事的结局和核心冲突。
-* **情感标记**: 孩子对特定内容的反应(根据“重播”、“跳过”推断)。
-* **新学词汇**: 故事中出现的高级词汇。
-
----
-
-## 4. 关键功能特性 (Feature Specs)
-
-### 4.1 智能开场白 (Memory Injection)
-在生成新故事时,Prompt 必须包含一段"记忆唤醒"指令。
-* **示例**: "小明,还记得上周我们帮小松鼠找回了松果吗?今天,小松鼠带来了一位新朋友..."
-* **策略**: 提取权重最高的 Top 3 记忆注入 Prompt。
-
-### 4.2 成长时间轴 (Growth Timeline)
-一个可视化的 H5 页面或 App 模块,以时间轴形式展示里程碑。
-* **节点类型**:
- * 🌟 **初次相遇**: 创建角色的第一天。
- * 📖 **阅读打卡**: 累计阅读 10/50/100 本。
- * 🏅 **获得成就**: 获得"诚实勋章"。
- * 🧠 **能力解锁**: 第一次阅读"科幻"题材。
-
-### 4.3 成就仪式感 (Achievement Ceremony)
-* **触发**: 故事生成并分析后,如果获得新成就。
-* **表现**: 弹窗动画 + 音效 + "恭喜获得 [勇气] 徽章"。
-* **分享**:允许生成带二维码的成就海报。
-
----
-
-## 5. 记忆类型扩展 (Memory Types)
-
-| 类型 Key | 描述 | 来源 | 过期策略 |
-|---------|------|------|---------|
-| `recent_story` | 最近读过的故事梗概 | 阅读事件 | 30天衰减 |
-| `favorite_character` | 孩子喜欢的角色 | 重播/高评分 | 长期有效 |
-| `scary_element` | 孩子害怕/不喜欢的元素 | 跳过/负反馈 | 长期有效 (避雷) |
-| `vocabulary_growth` | 新掌握的词汇 | 故事分析 | 90天衰减 |
-| `emotional_highlight` | 高光时刻 (如: 特别开心的情节) | 互动数据 | 60天衰减 |
-
----
-
-## 6. 实施路线图 (Roadmap)
-
-### Phase 1: 基础建设 (v0.3.0)
-* [x] 数据库 `MemoryItem` 表 (已存在)。
-* [ ] 扩展 `MemoryItem` 类型字段,支持更多维度。
-* [ ] 优化 `_build_memory_context`,支持更自然的 Prompt 注入。
-* [ ] 前端:简单的"近期回忆"展示列表。
-
-### Phase 2: 可视化与成就 (v0.4.0)
-* [ ] 实现"成就提取器" (Achievement Extractor) 的闭环通知。
-* [ ] 前端:开发"我的成就"和"成长时间轴"页面。
-* [ ] 增加故事开场白的动态生成逻辑。
-
-### Phase 3: 深度智能 (v0.5.0+)
-* [ ] 引入向量数据库,实现基于语义的记忆检索 (不仅是时间最近)。
-* [ ] 情感分析模型:分析用户行为推断情感倾向。
diff --git a/backend/docs/refactoring_plan.md b/backend/docs/refactoring_plan.md
new file mode 100644
index 0000000..73a59f2
--- /dev/null
+++ b/backend/docs/refactoring_plan.md
@@ -0,0 +1,98 @@
+# DreamWeaver 重构实施计划
+
+## 1. 概述
+
+本文档基于对当前架构的深入分析,制定了从稳定性、可维护性到可扩展性的分阶段重构计划。
+
+**目标**:
+- **短期**:解决单点故障风险,优化开发体验,清理关键技术债。
+- **中期**:提升系统高可用能力,增强监控与可观测性。
+- **长期**:架构演进,支持大规模并发与复杂业务场景。
+
+---
+
+## 2. 短期优化计划 (1-2周)
+
+**重点**:消除即时风险,提升部署效率。
+
+### 2.1 统一镜像构建 (High Priority)
+目前 `backend`, `backend-admin`, `worker`, `celery-beat` 重复构建 4 次,浪费资源且镜像版本可能不一致。
+
+- **Action Items**:
+ - [ ] 修改 `backend/Dockerfile` 为通用基础镜像。
+ - [ ] 更新 `docker-compose.yml`,定义 `backend-base` 服务或使用 `image` 标签共享镜像。
+ - [ ] 确保所有 Python 服务共用同一构建产物,仅启动命令不同。
+
+### 2.2 修复 Provider 缓存与限流 (High Priority)
+内存缓存 (`TTLCache`, `_latency_cache`) 在多进程/多实例下失效。
+
+- **Action Items**:
+ - [ ] 引入 Redis 作为共享缓存后端。
+ - [ ] 重构 `_load_provider_cache`,将 Provider 配置缓存至 Redis。
+ - [ ] 重构 `stories.py` 中的限流逻辑,使用 `redis-cell` 或简单的 Redis 计数器替代 `TTLCache`。
+
+### 2.3 拆分 `stories.py` (Medium Priority)
+`app/api/stories.py` 超过 600 行,包含 API 定义、业务逻辑、验证逻辑,维护困难。
+
+- **Action Items**:
+ - [ ] 创建 `app/services/story_service.py`,迁移生成、润色、PDF生成等核心逻辑。
+ - [ ] 创建 `app/schemas/story_schema.py`,迁移 Pydantic 模型(`GenerateRequest`, `StoryResponse` 等)。
+ - [ ] API 层 `stories.py` 仅保留路由定义和依赖注入,调用 Service 层。
+
+---
+
+## 3. 中期优化计划 (1-2月)
+
+**重点**:高可用 (HA) 与系统韧性。
+
+### 3.1 数据库高可用 (Critical)
+当前 PostgreSQL 为单点,且 Admin/User 混合使用。
+
+- **Action Items**:
+ - [ ] 部署 PostgreSQL 主从复制 (Master-Slave)。
+ - [ ] 配置 `PgBouncer` 或 SQLAlchemy 读写分离,减轻主库压力。
+ - [ ] 实施数据库自动备份策略 (如 `pg_dump` 定时上传 S3)。
+
+### 3.2 消息队列高可用 (Critical)
+Redis 单点故障将导致 Celery 任务全盘停摆。
+
+- **Action Items**:
+ - [ ] 迁移至 Redis Sentinel 或 Redis Cluster 模式。
+ - [ ] 更新 Celery 配置以支持 Sentinel/Cluster 连接串。
+
+### 3.3 增强可观测性 (Important)
+目前仅有简单的日志,缺乏系统级指标。
+
+- **Action Items**:
+ - [ ] 集成 Prometheus Client,暴露 `/metrics` 端点。
+ - [ ] 部署 Grafana + Prometheus,监控 API 延迟、QPS、Celery 队列积压情况。
+ - [ ] 完善 `ProviderMetrics`,增加可视化大盘,实时监控 AI 供应商的成本与成功率。
+
+---
+
+## 4. 长期架构演进 (季度规划)
+
+**重点**:业务解耦与规模化。
+
+### 4.1 统一 API 网关
+- **当前**:前端直连后端端口,CORS 配置分散。
+- **演进**:引入 Traefik 或 Nginx 作为统一网关,管理路由、SSL、全局限流、统一鉴权。
+
+### 4.2 前端工程合并
+- **当前**:User App 和 Admin Console 是完全独立的两个项目,但在组件和工具链上高度重复。
+- **演进**:使用一种 Monorepo 策略或基于路由的单一应用策略,共享组件库和类型定义,减少维护成本。
+
+### 4.3 事件驱动架构完善
+- **当前**:部分业务逻辑耦合在 API 中。
+- **演进**:扩展事件总线,将“阅读记录”、“成就解锁”、“通知推送”等非核心链路完全异步化,通过 Domain Events 解耦。
+
+---
+
+## 5. 实施路线图
+
+| 阶段 | 时间估算 | 关键里程碑 |
+| :--- | :--- | :--- |
+| **Phase 1: 基础夯实** | Week 1-2 | Docker 构建优化上线,Redis 替代内存缓存。 |
+| **Phase 2: 代码重构** | Week 3-4 | `stories.py` 拆分完成,Service 层建立。 |
+| **Phase 3: 高可用建设** | Month 2 | 数据库与 Redis 实现主备/集群模式。 |
+| **Phase 4: 监控体系** | Month 2 | Grafana 监控大盘上线,关键指标报警配置完毕。 |
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index a048af1..19786a0 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -121,7 +121,7 @@ def mock_text_provider():
cover_prompt_suggestion="A cute rabbit",
)
- with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock:
+ with patch("app.services.story_service.generate_story_content", new_callable=AsyncMock) as mock:
mock.return_value = mock_result
yield mock
@@ -129,7 +129,7 @@ def mock_text_provider():
@pytest.fixture
def mock_image_provider():
"""Mock 图像生成。"""
- with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock:
+ with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
mock.return_value = "https://example.com/image.png"
yield mock
@@ -137,7 +137,7 @@ def mock_image_provider():
@pytest.fixture
def mock_tts_provider():
"""Mock TTS。"""
- with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock:
+ with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
mock.return_value = b"fake-audio-bytes"
yield mock
diff --git a/backend/tests/test_stories.py b/backend/tests/test_stories.py
index b34b840..121f3fd 100644
--- a/backend/tests/test_stories.py
+++ b/backend/tests/test_stories.py
@@ -57,7 +57,6 @@ class TestStoryGenerate:
assert data["mode"] == "generated"
-@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
class TestStoryList:
"""故事列表测试。"""
@@ -94,7 +93,6 @@ class TestStoryList:
assert len(data) == 0
-@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
class TestStoryDetail:
"""故事详情测试。"""
@@ -118,7 +116,6 @@ class TestStoryDetail:
assert data["story_text"] == test_story.story_text
-@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
class TestStoryDelete:
"""故事删除测试。"""
@@ -168,7 +165,6 @@ class TestRateLimit:
assert "Too many requests" in response.json()["detail"]
-@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerate:
"""封面图片生成测试。"""
@@ -183,7 +179,6 @@ class TestImageGenerate:
assert response.status_code == 404
-@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
class TestAudio:
"""语音朗读测试。"""
@@ -233,7 +228,7 @@ class TestGenerateFull:
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
"""图片生成失败时返回部分成功。"""
- with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
+ with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error")
response = auth_client.post(
"/api/stories/generate/full",
@@ -261,7 +256,6 @@ class TestGenerateFull:
assert call_kwargs["education_theme"] == "勇气与友谊"
-@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
class TestImageGenerateSuccess:
"""封面图片生成成功测试。"""