Refactor Phase 2: Split stories.py into Schema/Service/Controller, add missing endpoints, fix async bug
This commit is contained in:
@@ -2,26 +2,32 @@ FROM python:3.11-slim
|
|||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# 安装系统依赖 (如果需要)
|
# 设置环境变量
|
||||||
# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/*
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
# 复制项目文件
|
# 安装系统工具 (curl用于可能的健康检查)
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# 1. 缓存层:仅复制依赖定义并安装
|
||||||
|
# 创建伪造的 app 目录以满足 pip install . 的要求
|
||||||
COPY pyproject.toml .
|
COPY pyproject.toml .
|
||||||
# 复制源码
|
RUN mkdir app && touch app/__init__.py
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
# 2. 源码层:复制真实代码
|
||||||
COPY app ./app
|
COPY app ./app
|
||||||
COPY alembic ./alembic
|
COPY alembic ./alembic
|
||||||
COPY alembic.ini .
|
COPY alembic.ini .
|
||||||
|
|
||||||
# 安装依赖
|
# 再次安装本身(不带依赖),确保源码更新被标记为已安装
|
||||||
# 使用 pip 安装当前目录 (.),会自动解析 pyproject.toml
|
RUN pip install --no-cache-dir --no-deps .
|
||||||
RUN pip install --no-cache-dir .
|
|
||||||
|
|
||||||
# 创建静态文件目录 (用于存放生成的图片)
|
# 创建静态文件目录
|
||||||
RUN mkdir -p static/images
|
RUN mkdir -p static/images
|
||||||
|
|
||||||
# 暴露端口
|
# 暴露端口
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
# 启动命令
|
# 默认启动命令(可被 docker-compose 覆盖)
|
||||||
# 生产环境建议使用 gunicorn 或 uvicorn --workers
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|||||||
@@ -2,136 +2,41 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator, Literal
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from fastapi.responses import Response
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.deps import require_user
|
from app.core.deps import require_user
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.db.database import get_db
|
|
||||||
from app.db.models import ChildProfile, Story, StoryUniverse, User
|
|
||||||
from app.services.provider_router import (
|
|
||||||
generate_image,
|
|
||||||
generate_story_content,
|
|
||||||
generate_storybook,
|
|
||||||
text_to_speech,
|
|
||||||
)
|
|
||||||
from app.core.rate_limiter import check_rate_limit
|
from app.core.rate_limiter import check_rate_limit
|
||||||
from app.tasks.achievements import extract_story_achievements
|
from app.db.database import get_db
|
||||||
|
from app.db.models import User
|
||||||
|
from app.schemas.story_schemas import (
|
||||||
|
GenerateRequest,
|
||||||
|
StoryResponse,
|
||||||
|
FullStoryResponse,
|
||||||
|
StorybookRequest,
|
||||||
|
StorybookResponse,
|
||||||
|
StoryListItem,
|
||||||
|
AchievementItem,
|
||||||
|
)
|
||||||
|
from app.services import story_service
|
||||||
|
from app.services.memory_service import build_enhanced_memory_context
|
||||||
|
from app.services.provider_router import (
|
||||||
|
generate_story_content,
|
||||||
|
generate_image,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
MAX_DATA_LENGTH = 2000
|
|
||||||
MAX_EDU_THEME_LENGTH = 200
|
|
||||||
MAX_TTS_LENGTH = 4000
|
|
||||||
|
|
||||||
RATE_LIMIT_WINDOW = 60 # seconds
|
RATE_LIMIT_WINDOW = 60 # seconds
|
||||||
RATE_LIMIT_REQUESTS = 10
|
RATE_LIMIT_REQUESTS = 10
|
||||||
|
|
||||||
|
|
||||||
class GenerateRequest(BaseModel):
|
|
||||||
"""Story generation request."""
|
|
||||||
|
|
||||||
type: Literal["keywords", "full_story"]
|
|
||||||
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
|
|
||||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
|
||||||
child_profile_id: str | None = None
|
|
||||||
universe_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StoryResponse(BaseModel):
|
|
||||||
"""Story response."""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
title: str
|
|
||||||
story_text: str
|
|
||||||
cover_prompt: str | None
|
|
||||||
image_url: str | None
|
|
||||||
mode: str
|
|
||||||
child_profile_id: str | None = None
|
|
||||||
universe_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StoryListItem(BaseModel):
|
|
||||||
"""Story list item."""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
title: str
|
|
||||||
image_url: str | None
|
|
||||||
created_at: str
|
|
||||||
mode: str
|
|
||||||
|
|
||||||
|
|
||||||
class FullStoryResponse(BaseModel):
|
|
||||||
"""完整故事响应(含图片和音频状态)。"""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
title: str
|
|
||||||
story_text: str
|
|
||||||
cover_prompt: str | None
|
|
||||||
image_url: str | None
|
|
||||||
audio_ready: bool
|
|
||||||
mode: str
|
|
||||||
errors: dict[str, str | None] = Field(default_factory=dict)
|
|
||||||
child_profile_id: str | None = None
|
|
||||||
universe_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
from app.services.memory_service import build_enhanced_memory_context
|
|
||||||
|
|
||||||
|
|
||||||
async def _validate_profile_and_universe(
|
|
||||||
request: GenerateRequest,
|
|
||||||
user: User,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> tuple[str | None, str | None]:
|
|
||||||
if not request.child_profile_id and not request.universe_id:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
profile_id = request.child_profile_id
|
|
||||||
universe_id = request.universe_id
|
|
||||||
|
|
||||||
if profile_id:
|
|
||||||
result = await db.execute(
|
|
||||||
select(ChildProfile).where(
|
|
||||||
ChildProfile.id == profile_id,
|
|
||||||
ChildProfile.user_id == user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
profile = result.scalar_one_or_none()
|
|
||||||
if not profile:
|
|
||||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
result = await db.execute(
|
|
||||||
select(StoryUniverse)
|
|
||||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
|
||||||
.where(
|
|
||||||
StoryUniverse.id == universe_id,
|
|
||||||
ChildProfile.user_id == user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
universe = result.scalar_one_or_none()
|
|
||||||
if not universe:
|
|
||||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
|
||||||
if profile_id and universe.child_profile_id != profile_id:
|
|
||||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
|
||||||
if not profile_id:
|
|
||||||
profile_id = universe.child_profile_id
|
|
||||||
|
|
||||||
return profile_id, universe_id
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stories/generate", response_model=StoryResponse)
|
@router.post("/stories/generate", response_model=StoryResponse)
|
||||||
async def generate_story(
|
async def generate_story(
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
@@ -140,47 +45,7 @@ async def generate_story(
|
|||||||
):
|
):
|
||||||
"""Generate or enhance a story."""
|
"""Generate or enhance a story."""
|
||||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
return await story_service.generate_and_save_story(request, user.id, db)
|
||||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await generate_story_content(
|
|
||||||
input_type=request.type,
|
|
||||||
data=request.data,
|
|
||||||
education_theme=request.education_theme,
|
|
||||||
memory_context=memory_context,
|
|
||||||
)
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
|
||||||
|
|
||||||
story = Story(
|
|
||||||
user_id=user.id,
|
|
||||||
child_profile_id=profile_id,
|
|
||||||
universe_id=universe_id,
|
|
||||||
title=result.title,
|
|
||||||
story_text=result.story_text,
|
|
||||||
cover_prompt=result.cover_prompt_suggestion,
|
|
||||||
mode=result.mode,
|
|
||||||
)
|
|
||||||
db.add(story)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(story)
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
extract_story_achievements.delay(story.id, universe_id)
|
|
||||||
|
|
||||||
return StoryResponse(
|
|
||||||
id=story.id,
|
|
||||||
title=story.title,
|
|
||||||
story_text=story.story_text,
|
|
||||||
cover_prompt=story.cover_prompt,
|
|
||||||
image_url=story.image_url,
|
|
||||||
mode=story.mode,
|
|
||||||
child_profile_id=story.child_profile_id,
|
|
||||||
universe_id=story.universe_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
@router.post("/stories/generate/full", response_model=FullStoryResponse)
|
||||||
@@ -189,71 +54,9 @@ async def generate_story_full(
|
|||||||
user: User = Depends(require_user),
|
user: User = Depends(require_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""生成完整故事(故事 + 并行生成图片和音频)。
|
"""Generate complete story (story + parallel image/audio generation)."""
|
||||||
|
|
||||||
部分成功策略:故事必须成功,图片/音频失败不影响整体。
|
|
||||||
"""
|
|
||||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
return await story_service.generate_full_story_service(request, user.id, db)
|
||||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
|
||||||
|
|
||||||
# Step 1: 故事生成(必须成功)
|
|
||||||
try:
|
|
||||||
result = await generate_story_content(
|
|
||||||
input_type=request.type,
|
|
||||||
data=request.data,
|
|
||||||
education_theme=request.education_theme,
|
|
||||||
memory_context=memory_context,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("story_generation_failed", error=str(exc))
|
|
||||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
|
|
||||||
|
|
||||||
# 保存故事
|
|
||||||
story = Story(
|
|
||||||
user_id=user.id,
|
|
||||||
child_profile_id=profile_id,
|
|
||||||
universe_id=universe_id,
|
|
||||||
title=result.title,
|
|
||||||
story_text=result.story_text,
|
|
||||||
cover_prompt=result.cover_prompt_suggestion,
|
|
||||||
mode=result.mode,
|
|
||||||
)
|
|
||||||
db.add(story)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(story)
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
extract_story_achievements.delay(story.id, universe_id)
|
|
||||||
|
|
||||||
# Step 2: 生成封面图片(音频按需生成,避免浪费)
|
|
||||||
errors: dict[str, str | None] = {}
|
|
||||||
image_url: str | None = None
|
|
||||||
|
|
||||||
if story.cover_prompt:
|
|
||||||
try:
|
|
||||||
image_url = await generate_image(story.cover_prompt)
|
|
||||||
story.image_url = image_url
|
|
||||||
await db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
errors["image"] = str(exc)
|
|
||||||
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
|
|
||||||
|
|
||||||
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
|
|
||||||
# 这样避免生成后丢弃造成的成本浪费
|
|
||||||
|
|
||||||
return FullStoryResponse(
|
|
||||||
id=story.id,
|
|
||||||
title=story.title,
|
|
||||||
story_text=story.story_text,
|
|
||||||
cover_prompt=story.cover_prompt,
|
|
||||||
image_url=image_url,
|
|
||||||
audio_ready=False, # 音频需要用户主动请求
|
|
||||||
mode=story.mode,
|
|
||||||
errors=errors,
|
|
||||||
child_profile_id=story.child_profile_id,
|
|
||||||
universe_id=story.universe_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/stories/generate/stream")
|
@router.post("/stories/generate/stream")
|
||||||
@@ -263,53 +66,39 @@ async def generate_story_stream(
|
|||||||
user: User = Depends(require_user),
|
user: User = Depends(require_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""流式生成故事(SSE)。
|
"""流式生成故事(SSE)。"""
|
||||||
|
|
||||||
事件流程:
|
|
||||||
- started: 返回 story_id
|
|
||||||
- story_ready: 返回 title, content
|
|
||||||
- story_failed: 返回 error
|
|
||||||
- image_ready: 返回 image_url
|
|
||||||
- image_failed: 返回 error
|
|
||||||
- complete: 结束流
|
|
||||||
"""
|
|
||||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||||
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
|
|
||||||
|
# Validation
|
||||||
|
profile_id, universe_id = await story_service.validate_profile_and_universe(
|
||||||
|
request.child_profile_id, request.universe_id, user.id, db
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build Context
|
||||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[dict, None]:
|
async def event_generator() -> AsyncGenerator[dict, None]:
|
||||||
story_id = str(uuid.uuid4())
|
story_id = str(uuid.uuid4())
|
||||||
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
|
||||||
|
|
||||||
# Step 1: 生成故事
|
# Step 1: Generate Content
|
||||||
try:
|
try:
|
||||||
result = await generate_story_content(
|
result = await generate_story_content(
|
||||||
input_type=request.type,
|
input_type=request.type,
|
||||||
data=request.data,
|
data=request.data,
|
||||||
education_theme=request.education_theme,
|
education_theme=request.education_theme,
|
||||||
memory_context=memory_context,
|
memory_context=memory_context,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("sse_story_generation_failed", error=str(e))
|
logger.error("sse_story_generation_failed", error=str(e))
|
||||||
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 保存故事
|
# Save Story
|
||||||
story = Story(
|
story = await story_service.create_story_from_result(
|
||||||
user_id=user.id,
|
result, user.id, profile_id, universe_id, db
|
||||||
child_profile_id=profile_id,
|
|
||||||
universe_id=universe_id,
|
|
||||||
title=result.title,
|
|
||||||
story_text=result.story_text,
|
|
||||||
cover_prompt=result.cover_prompt_suggestion,
|
|
||||||
mode=result.mode,
|
|
||||||
)
|
)
|
||||||
db.add(story)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(story)
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
extract_story_achievements.delay(story.id, universe_id)
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "story_ready",
|
"event": "story_ready",
|
||||||
@@ -324,10 +113,11 @@ async def generate_story_stream(
|
|||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Step 2: 并行生成图片(音频按需)
|
# Step 2: Generate Image
|
||||||
if story.cover_prompt:
|
if story.cover_prompt:
|
||||||
try:
|
try:
|
||||||
image_url = await generate_image(story.cover_prompt)
|
# Direct call to provider router's generate_image, sharing db session
|
||||||
|
image_url = await generate_image(story.cover_prompt, db=db)
|
||||||
story.image_url = image_url
|
story.image_url = image_url
|
||||||
await db.commit()
|
await db.commit()
|
||||||
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
|
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
|
||||||
@@ -340,217 +130,71 @@ async def generate_story_stream(
|
|||||||
return EventSourceResponse(event_generator())
|
return EventSourceResponse(event_generator())
|
||||||
|
|
||||||
|
|
||||||
# ==================== Storybook API ====================
|
|
||||||
|
|
||||||
|
|
||||||
class StorybookRequest(BaseModel):
|
|
||||||
"""Storybook 生成请求。"""
|
|
||||||
|
|
||||||
keywords: str = Field(..., min_length=1, max_length=200)
|
|
||||||
page_count: int = Field(default=6, ge=4, le=12)
|
|
||||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
|
||||||
generate_images: bool = Field(default=False, description="是否同时生成插图")
|
|
||||||
child_profile_id: str | None = None
|
|
||||||
universe_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StorybookPageResponse(BaseModel):
|
|
||||||
"""故事书单页响应。"""
|
|
||||||
|
|
||||||
page_number: int
|
|
||||||
text: str
|
|
||||||
image_prompt: str
|
|
||||||
image_url: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StorybookResponse(BaseModel):
|
|
||||||
"""故事书响应。"""
|
|
||||||
|
|
||||||
id: int | None = None
|
|
||||||
title: str
|
|
||||||
main_character: str
|
|
||||||
art_style: str
|
|
||||||
pages: list[StorybookPageResponse]
|
|
||||||
cover_prompt: str
|
|
||||||
cover_url: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/storybook/generate", response_model=StorybookResponse)
|
@router.post("/storybook/generate", response_model=StorybookResponse)
|
||||||
async def generate_storybook_api(
|
async def generate_storybook_api(
|
||||||
request: StorybookRequest,
|
request: StorybookRequest,
|
||||||
user: User = Depends(require_user),
|
user: User = Depends(require_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""生成分页故事书并保存。
|
"""Generate storybook."""
|
||||||
|
|
||||||
返回故事书结构,包含每页文字和图像提示词。
|
|
||||||
"""
|
|
||||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||||
|
return await story_service.generate_storybook_service(request, user.id, db)
|
||||||
# 验证档案和宇宙
|
|
||||||
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
|
|
||||||
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
|
|
||||||
|
|
||||||
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
|
|
||||||
profile_id = request.child_profile_id
|
|
||||||
universe_id = request.universe_id
|
|
||||||
|
|
||||||
if profile_id:
|
|
||||||
result = await db.execute(
|
|
||||||
select(ChildProfile).where(
|
|
||||||
ChildProfile.id == profile_id,
|
|
||||||
ChildProfile.user_id == user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not result.scalar_one_or_none():
|
|
||||||
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
result = await db.execute(
|
|
||||||
select(StoryUniverse)
|
|
||||||
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
|
||||||
.where(
|
|
||||||
StoryUniverse.id == universe_id,
|
|
||||||
ChildProfile.user_id == user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
universe = result.scalar_one_or_none()
|
|
||||||
if not universe:
|
|
||||||
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
|
||||||
if profile_id and universe.child_profile_id != profile_id:
|
|
||||||
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
|
||||||
if not profile_id:
|
|
||||||
profile_id = universe.child_profile_id
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"storybook_request",
|
|
||||||
user_id=user.id,
|
|
||||||
keywords=request.keywords,
|
|
||||||
page_count=request.page_count,
|
|
||||||
profile_id=profile_id,
|
|
||||||
universe_id=universe_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 注意:generate_storybook 目前可能不支持记忆上下文注入
|
|
||||||
# 我们需要看看 generate_storybook 的签名
|
|
||||||
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
|
|
||||||
storybook = await generate_storybook(
|
|
||||||
keywords=request.keywords,
|
|
||||||
page_count=request.page_count,
|
|
||||||
education_theme=request.education_theme,
|
|
||||||
memory_context=memory_context,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("storybook_generation_failed", error=str(e))
|
|
||||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# 核心升级: 并行全量生成 (Parallel Full Rendering)
|
|
||||||
# ==============================================================================
|
|
||||||
final_cover_url = storybook.cover_url
|
|
||||||
|
|
||||||
if request.generate_images:
|
|
||||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
|
||||||
|
|
||||||
# 1. 准备所有生图任务 (封面 + 所有内页)
|
|
||||||
tasks = []
|
|
||||||
|
|
||||||
# 封面任务
|
|
||||||
async def _gen_cover():
|
|
||||||
if storybook.cover_prompt and not storybook.cover_url:
|
|
||||||
try:
|
|
||||||
return await generate_image(storybook.cover_prompt, db=db)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("cover_gen_failed", error=str(e))
|
|
||||||
return storybook.cover_url
|
|
||||||
|
|
||||||
tasks.append(_gen_cover())
|
|
||||||
|
|
||||||
# 内页任务
|
|
||||||
async def _gen_page(page):
|
|
||||||
if page.image_prompt and not page.image_url:
|
|
||||||
try:
|
|
||||||
url = await generate_image(page.image_prompt, db=db)
|
|
||||||
page.image_url = url
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
|
|
||||||
|
|
||||||
for page in storybook.pages:
|
|
||||||
tasks.append(_gen_page(page))
|
|
||||||
|
|
||||||
# 2. 并发执行所有任务
|
|
||||||
# 使用 return_exceptions=True 防止单张失败影响整体
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
|
|
||||||
cover_res = results[0]
|
|
||||||
if isinstance(cover_res, str):
|
|
||||||
final_cover_url = cover_res
|
|
||||||
|
|
||||||
logger.info("storybook_parallel_generation_complete")
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
# 构建并保存 Story 对象
|
|
||||||
# 将 pages 对象转换为字典列表以存入 JSON 字段
|
|
||||||
pages_data = [
|
|
||||||
{
|
|
||||||
"page_number": p.page_number,
|
|
||||||
"text": p.text,
|
|
||||||
"image_prompt": p.image_prompt,
|
|
||||||
"image_url": p.image_url,
|
|
||||||
}
|
|
||||||
for p in storybook.pages
|
|
||||||
]
|
|
||||||
|
|
||||||
story = Story(
|
|
||||||
user_id=user.id,
|
|
||||||
child_profile_id=profile_id,
|
|
||||||
universe_id=universe_id,
|
|
||||||
title=storybook.title,
|
|
||||||
mode="storybook",
|
|
||||||
pages=pages_data, # 存入 JSON 字段
|
|
||||||
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
|
|
||||||
cover_prompt=storybook.cover_prompt,
|
|
||||||
image_url=final_cover_url,
|
|
||||||
)
|
|
||||||
db.add(story)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(story)
|
|
||||||
|
|
||||||
if universe_id:
|
|
||||||
extract_story_achievements.delay(story.id, universe_id)
|
|
||||||
|
|
||||||
# 构建响应 (使用更新后的 pages_data)
|
|
||||||
response_pages = [
|
|
||||||
StorybookPageResponse(
|
|
||||||
page_number=p["page_number"],
|
|
||||||
text=p["text"],
|
|
||||||
image_prompt=p["image_prompt"],
|
|
||||||
image_url=p.get("image_url"),
|
|
||||||
)
|
|
||||||
for p in pages_data
|
|
||||||
]
|
|
||||||
|
|
||||||
return StorybookResponse(
|
|
||||||
id=story.id,
|
|
||||||
title=storybook.title,
|
|
||||||
main_character=storybook.main_character,
|
|
||||||
art_style=storybook.art_style,
|
|
||||||
pages=response_pages,
|
|
||||||
cover_prompt=storybook.cover_prompt,
|
|
||||||
cover_url=final_cover_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AchievementItem(BaseModel):
|
# ==================== Missing Endpoints (Issue #5) ====================
|
||||||
type: str
|
|
||||||
description: str
|
@router.get("/stories", response_model=list[StoryListItem])
|
||||||
obtained_at: str | None = None
|
async def list_stories(
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
user: User = Depends(require_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""List stories."""
|
||||||
|
return await story_service.list_stories(user.id, limit, offset, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stories/{story_id}", response_model=StoryResponse)
|
||||||
|
async def get_story(
|
||||||
|
story_id: int,
|
||||||
|
user: User = Depends(require_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get story detail."""
|
||||||
|
return await story_service.get_story_detail(story_id, user.id, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/stories/{story_id}")
|
||||||
|
async def delete_story(
|
||||||
|
story_id: int,
|
||||||
|
user: User = Depends(require_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Delete story."""
|
||||||
|
await story_service.delete_story(story_id, user.id, db)
|
||||||
|
return {"message": "Deleted"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/image/generate/{story_id}")
|
||||||
|
async def generate_story_image(
|
||||||
|
story_id: int,
|
||||||
|
user: User = Depends(require_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Generate cover image for story."""
|
||||||
|
url = await story_service.generate_story_cover(story_id, user.id, db)
|
||||||
|
return {"image_url": url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/audio/{story_id}")
|
||||||
|
async def get_story_audio(
|
||||||
|
story_id: int,
|
||||||
|
user: User = Depends(require_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Get story audio (MP3)."""
|
||||||
|
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
|
||||||
|
return Response(content=audio_bytes, media_type="audio/mpeg")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
|
||||||
@@ -559,32 +203,5 @@ async def get_story_achievements(
|
|||||||
user: User = Depends(require_user),
|
user: User = Depends(require_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get achievements unlocked by a specific story."""
|
"""Get story achievements."""
|
||||||
# 使用 joinedload 避免 N+1 查询
|
return await story_service.get_story_achievements(story_id, user.id, db)
|
||||||
result = await db.execute(
|
|
||||||
select(Story)
|
|
||||||
.options(joinedload(Story.story_universe))
|
|
||||||
.where(Story.id == story_id, Story.user_id == user.id)
|
|
||||||
)
|
|
||||||
story = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not story:
|
|
||||||
raise HTTPException(status_code=404, detail="Story not found")
|
|
||||||
|
|
||||||
if not story.universe_id or not story.story_universe:
|
|
||||||
return []
|
|
||||||
|
|
||||||
universe = story.story_universe
|
|
||||||
if not universe.achievements:
|
|
||||||
return []
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for ach in universe.achievements:
|
|
||||||
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
|
||||||
results.append(AchievementItem(
|
|
||||||
type=ach.get("type", "Unknown"),
|
|
||||||
description=ach.get("description", ""),
|
|
||||||
obtained_at=ach.get("obtained_at")
|
|
||||||
))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ class Settings(BaseSettings):
|
|||||||
celery_broker_url: str = Field("redis://localhost:6379/0")
|
celery_broker_url: str = Field("redis://localhost:6379/0")
|
||||||
celery_result_backend: str = Field("redis://localhost:6379/0")
|
celery_result_backend: str = Field("redis://localhost:6379/0")
|
||||||
|
|
||||||
|
# Generic Redis
|
||||||
|
redis_url: str = Field("redis://localhost:6379/0", description="Redis connection URL")
|
||||||
|
|
||||||
# Admin console
|
# Admin console
|
||||||
enable_admin_console: bool = False
|
enable_admin_console: bool = False
|
||||||
admin_username: str = "admin"
|
admin_username: str = "admin"
|
||||||
|
|||||||
25
backend/app/core/redis.py
Normal file
25
backend/app/core/redis.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Redis client module."""
|
||||||
|
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from redis.asyncio import Redis, from_url
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
_redis_pool: Redis | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis() -> Redis:
|
||||||
|
"""Get global Redis client instance."""
|
||||||
|
global _redis_pool
|
||||||
|
if _redis_pool is None:
|
||||||
|
_redis_pool = from_url(settings.redis_url, encoding="utf-8", decode_responses=True)
|
||||||
|
return _redis_pool
|
||||||
|
|
||||||
|
|
||||||
|
async def close_redis():
|
||||||
|
"""Close Redis connection."""
|
||||||
|
global _redis_pool
|
||||||
|
if _redis_pool:
|
||||||
|
await _redis_pool.close()
|
||||||
|
_redis_pool = None
|
||||||
1
backend/app/schemas/__init__.py
Normal file
1
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""故事相关 Schema 模块。"""
|
||||||
106
backend/app/schemas/story_schemas.py
Normal file
106
backend/app/schemas/story_schemas.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""故事相关 Pydantic 模型。"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
MAX_DATA_LENGTH = 2000
|
||||||
|
MAX_EDU_THEME_LENGTH = 200
|
||||||
|
MAX_TTS_LENGTH = 4000
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 故事模型 ====================
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
"""Story generation request."""
|
||||||
|
|
||||||
|
type: Literal["keywords", "full_story"]
|
||||||
|
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
|
||||||
|
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||||
|
child_profile_id: str | None = None
|
||||||
|
universe_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StoryResponse(BaseModel):
|
||||||
|
"""Story response."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
title: str
|
||||||
|
story_text: str
|
||||||
|
cover_prompt: str | None
|
||||||
|
image_url: str | None
|
||||||
|
mode: str
|
||||||
|
child_profile_id: str | None = None
|
||||||
|
universe_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StoryListItem(BaseModel):
|
||||||
|
"""Story list item."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
title: str
|
||||||
|
image_url: str | None
|
||||||
|
created_at: datetime
|
||||||
|
mode: str
|
||||||
|
|
||||||
|
|
||||||
|
class FullStoryResponse(BaseModel):
|
||||||
|
"""完整故事响应(含图片和音频状态)。"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
title: str
|
||||||
|
story_text: str
|
||||||
|
cover_prompt: str | None
|
||||||
|
image_url: str | None
|
||||||
|
audio_ready: bool
|
||||||
|
mode: str
|
||||||
|
errors: dict[str, str | None] = Field(default_factory=dict)
|
||||||
|
child_profile_id: str | None = None
|
||||||
|
universe_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 绘本模型 ====================
|
||||||
|
|
||||||
|
|
||||||
|
class StorybookRequest(BaseModel):
|
||||||
|
"""Storybook 生成请求。"""
|
||||||
|
|
||||||
|
keywords: str = Field(..., min_length=1, max_length=200)
|
||||||
|
page_count: int = Field(default=6, ge=4, le=12)
|
||||||
|
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||||
|
generate_images: bool = Field(default=False, description="是否同时生成插图")
|
||||||
|
child_profile_id: str | None = None
|
||||||
|
universe_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StorybookPageResponse(BaseModel):
|
||||||
|
"""故事书单页响应。"""
|
||||||
|
|
||||||
|
page_number: int
|
||||||
|
text: str
|
||||||
|
image_prompt: str
|
||||||
|
image_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StorybookResponse(BaseModel):
|
||||||
|
"""故事书响应。"""
|
||||||
|
|
||||||
|
id: int | None = None
|
||||||
|
title: str
|
||||||
|
main_character: str
|
||||||
|
art_style: str
|
||||||
|
pages: list[StorybookPageResponse]
|
||||||
|
cover_prompt: str
|
||||||
|
cover_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 成就模型 ====================
|
||||||
|
|
||||||
|
|
||||||
|
class AchievementItem(BaseModel):
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
obtained_at: str | None = None
|
||||||
@@ -1,31 +1,109 @@
|
|||||||
"""In-memory cache for providers loaded from DB."""
|
"""Redis-backed cache for providers loaded from DB."""
|
||||||
|
|
||||||
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.core.redis import get_redis
|
||||||
from app.db.admin_models import Provider
|
from app.db.admin_models import Provider
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||||
|
|
||||||
_cache: dict[ProviderType, list[Provider]] = defaultdict(list)
|
|
||||||
|
class CachedProvider(BaseModel):
|
||||||
|
"""Serializable provider configuration matching DB model fields."""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
adapter: str
|
||||||
|
model: str | None = None
|
||||||
|
api_base: str | None = None
|
||||||
|
api_key: str | None = None
|
||||||
|
timeout_ms: int = 60000
|
||||||
|
max_retries: int = 1
|
||||||
|
weight: int = 1
|
||||||
|
priority: int = 0
|
||||||
|
enabled: bool = True
|
||||||
|
config_json: dict | None = None
|
||||||
|
config_ref: str | None = None
|
||||||
|
|
||||||
|
|
||||||
async def reload_providers(db: AsyncSession):
|
# Local memory fallback (L1 cache)
|
||||||
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
|
_local_cache: dict[ProviderType, list[CachedProvider]] = defaultdict(list)
|
||||||
providers = result.scalars().all()
|
CACHE_KEY = "dreamweaver:providers:config"
|
||||||
grouped: dict[ProviderType, list[Provider]] = defaultdict(list)
|
|
||||||
for p in providers:
|
|
||||||
grouped[p.type].append(p)
|
|
||||||
# sort by priority desc, then weight desc
|
|
||||||
for k in grouped:
|
|
||||||
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
|
|
||||||
_cache.clear()
|
|
||||||
_cache.update(grouped)
|
|
||||||
return _cache
|
|
||||||
|
|
||||||
|
|
||||||
def get_providers(provider_type: ProviderType) -> list[Provider]:
|
async def reload_providers(db: AsyncSession) -> dict[ProviderType, list[CachedProvider]]:
|
||||||
return _cache.get(provider_type, [])
|
"""Reload providers from DB and update Redis cache."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Provider).where(Provider.enabled == True)) # noqa: E712
|
||||||
|
providers = result.scalars().all()
|
||||||
|
|
||||||
|
# Convert to Pydantic models
|
||||||
|
cached_list = []
|
||||||
|
for p in providers:
|
||||||
|
cached_list.append(CachedProvider(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
type=p.type,
|
||||||
|
adapter=p.adapter,
|
||||||
|
model=p.model,
|
||||||
|
api_base=p.api_base,
|
||||||
|
api_key=p.api_key,
|
||||||
|
timeout_ms=p.timeout_ms,
|
||||||
|
max_retries=p.max_retries,
|
||||||
|
weight=p.weight,
|
||||||
|
priority=p.priority,
|
||||||
|
enabled=p.enabled,
|
||||||
|
config_json=p.config_json,
|
||||||
|
config_ref=p.config_ref
|
||||||
|
))
|
||||||
|
|
||||||
|
# Group by type
|
||||||
|
grouped: dict[str, list[CachedProvider]] = defaultdict(list)
|
||||||
|
for cp in cached_list:
|
||||||
|
grouped[cp.type].append(cp)
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
for k in grouped:
|
||||||
|
grouped[k].sort(key=lambda x: (x.priority, x.weight), reverse=True)
|
||||||
|
|
||||||
|
# Update Redis
|
||||||
|
redis = await get_redis()
|
||||||
|
# Serialize entire dict structure
|
||||||
|
# Pydantic -> dict -> json
|
||||||
|
json_data = {k: [p.model_dump() for p in v] for k, v in grouped.items()}
|
||||||
|
await redis.set(CACHE_KEY, json.dumps(json_data))
|
||||||
|
|
||||||
|
# Update local cache
|
||||||
|
_local_cache.clear()
|
||||||
|
_local_cache.update(grouped)
|
||||||
|
return grouped
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("failed_to_reload_providers", error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def get_providers(provider_type: ProviderType) -> list[CachedProvider]:
|
||||||
|
"""Get providers from Redis (preferred) or local fallback."""
|
||||||
|
try:
|
||||||
|
redis = await get_redis()
|
||||||
|
data = await redis.get(CACHE_KEY)
|
||||||
|
if data:
|
||||||
|
raw_dict = json.loads(data)
|
||||||
|
if provider_type in raw_dict:
|
||||||
|
return [CachedProvider(**item) for item in raw_dict[provider_type]]
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("redis_cache_read_failed", error=str(e))
|
||||||
|
|
||||||
|
# Fallback to local memory
|
||||||
|
return _local_cache.get(provider_type, [])
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_providers_with_config(
|
async def _get_providers_with_config(
|
||||||
provider_type: ProviderType,
|
provider_type: ProviderType,
|
||||||
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
||||||
"""获取供应商列表及其配置。
|
"""获取供应商列表及其配置。
|
||||||
@@ -185,7 +185,7 @@ def _get_providers_with_config(
|
|||||||
Returns:
|
Returns:
|
||||||
[(adapter_name, config, provider_or_none), ...] 按优先级排序
|
[(adapter_name, config, provider_or_none), ...] 按优先级排序
|
||||||
"""
|
"""
|
||||||
db_providers = get_providers(provider_type)
|
db_providers = await get_providers(provider_type)
|
||||||
|
|
||||||
if db_providers:
|
if db_providers:
|
||||||
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
||||||
@@ -265,7 +265,7 @@ async def _route_with_failover(
|
|||||||
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
||||||
**kwargs: 传递给适配器的参数
|
**kwargs: 传递给适配器的参数
|
||||||
"""
|
"""
|
||||||
providers = _get_providers_with_config(provider_type)
|
providers = await _get_providers_with_config(provider_type)
|
||||||
|
|
||||||
if not providers:
|
if not providers:
|
||||||
raise ValueError(f"No {provider_type} providers configured.")
|
raise ValueError(f"No {provider_type} providers configured.")
|
||||||
|
|||||||
434
backend/app/services/story_service.py
Normal file
434
backend/app/services/story_service.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
"""Story business logic service."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.db.models import ChildProfile, Story, StoryUniverse
|
||||||
|
from app.schemas.story_schemas import (
|
||||||
|
GenerateRequest,
|
||||||
|
StorybookRequest,
|
||||||
|
FullStoryResponse,
|
||||||
|
StorybookResponse,
|
||||||
|
StorybookPageResponse,
|
||||||
|
AchievementItem,
|
||||||
|
)
|
||||||
|
from app.services.memory_service import build_enhanced_memory_context
|
||||||
|
from app.services.provider_router import (
|
||||||
|
generate_story_content,
|
||||||
|
generate_image,
|
||||||
|
generate_storybook,
|
||||||
|
)
|
||||||
|
from app.tasks.achievements import extract_story_achievements
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_profile_and_universe(
|
||||||
|
profile_id: str | None,
|
||||||
|
universe_id: str | None,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> tuple[str | None, str | None]:
|
||||||
|
"""Validate child profile and universe ownership/relationship."""
|
||||||
|
if not profile_id and not universe_id:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if profile_id:
|
||||||
|
result = await db.execute(
|
||||||
|
select(ChildProfile).where(
|
||||||
|
ChildProfile.id == profile_id,
|
||||||
|
ChildProfile.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
profile = result.scalar_one_or_none()
|
||||||
|
if not profile:
|
||||||
|
raise HTTPException(status_code=404, detail="孩子档案不存在")
|
||||||
|
|
||||||
|
if universe_id:
|
||||||
|
result = await db.execute(
|
||||||
|
select(StoryUniverse)
|
||||||
|
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
|
||||||
|
.where(
|
||||||
|
StoryUniverse.id == universe_id,
|
||||||
|
ChildProfile.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
universe = result.scalar_one_or_none()
|
||||||
|
if not universe:
|
||||||
|
raise HTTPException(status_code=404, detail="故事宇宙不存在")
|
||||||
|
if profile_id and universe.child_profile_id != profile_id:
|
||||||
|
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
|
||||||
|
if not profile_id:
|
||||||
|
profile_id = universe.child_profile_id
|
||||||
|
|
||||||
|
return profile_id, universe_id
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_and_save_story(
|
||||||
|
request: GenerateRequest,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> Story:
|
||||||
|
"""Generate generic story content and save to DB."""
|
||||||
|
# 1. Validate
|
||||||
|
profile_id, universe_id = await validate_profile_and_universe(
|
||||||
|
request.child_profile_id, request.universe_id, user_id, db
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Build Context
|
||||||
|
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||||
|
|
||||||
|
# 3. Generate
|
||||||
|
try:
|
||||||
|
result = await generate_story_content(
|
||||||
|
input_type=request.type,
|
||||||
|
data=request.data,
|
||||||
|
education_theme=request.education_theme,
|
||||||
|
memory_context=memory_context,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
|
||||||
|
|
||||||
|
# 4. Save
|
||||||
|
story = Story(
|
||||||
|
user_id=user_id,
|
||||||
|
child_profile_id=profile_id,
|
||||||
|
universe_id=universe_id,
|
||||||
|
title=result.title,
|
||||||
|
story_text=result.story_text,
|
||||||
|
cover_prompt=result.cover_prompt_suggestion,
|
||||||
|
mode=result.mode,
|
||||||
|
)
|
||||||
|
db.add(story)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(story)
|
||||||
|
|
||||||
|
# 5. Trigger Async Tasks
|
||||||
|
if universe_id:
|
||||||
|
extract_story_achievements.delay(story.id, universe_id)
|
||||||
|
|
||||||
|
return story
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_full_story_service(
|
||||||
|
request: GenerateRequest,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> FullStoryResponse:
|
||||||
|
"""Generate story with parallel image generation."""
|
||||||
|
# 1. Generate text part
|
||||||
|
# We can reuse logic or call generate_story_content directly if we want finer control
|
||||||
|
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
|
||||||
|
story = await generate_and_save_story(request, user_id, db)
|
||||||
|
|
||||||
|
# 2. Generate Image (Parallel/Async step in this flow)
|
||||||
|
image_url: str | None = None
|
||||||
|
errors: dict[str, str | None] = {}
|
||||||
|
|
||||||
|
if story.cover_prompt:
|
||||||
|
try:
|
||||||
|
image_url = await generate_image(story.cover_prompt, db=db)
|
||||||
|
story.image_url = image_url
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
errors["image"] = str(exc)
|
||||||
|
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
|
||||||
|
|
||||||
|
return FullStoryResponse(
|
||||||
|
id=story.id,
|
||||||
|
title=story.title,
|
||||||
|
story_text=story.story_text,
|
||||||
|
cover_prompt=story.cover_prompt,
|
||||||
|
image_url=image_url,
|
||||||
|
audio_ready=False,
|
||||||
|
mode=story.mode,
|
||||||
|
errors=errors,
|
||||||
|
child_profile_id=story.child_profile_id,
|
||||||
|
universe_id=story.universe_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_storybook_service(
|
||||||
|
request: StorybookRequest,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> StorybookResponse:
|
||||||
|
"""Generate storybook with parallel image generation for pages."""
|
||||||
|
# 1. Validate
|
||||||
|
profile_id, universe_id = await validate_profile_and_universe(
|
||||||
|
request.child_profile_id, request.universe_id, user_id, db
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"storybook_request",
|
||||||
|
user_id=user_id,
|
||||||
|
keywords=request.keywords,
|
||||||
|
page_count=request.page_count,
|
||||||
|
profile_id=profile_id,
|
||||||
|
universe_id=universe_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Context
|
||||||
|
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||||
|
|
||||||
|
# 3. Generate Text Structure
|
||||||
|
try:
|
||||||
|
storybook = await generate_storybook(
|
||||||
|
keywords=request.keywords,
|
||||||
|
page_count=request.page_count,
|
||||||
|
education_theme=request.education_theme,
|
||||||
|
memory_context=memory_context,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("storybook_generation_failed", error=str(e))
|
||||||
|
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||||
|
|
||||||
|
# 4. Parallel Image Generation
|
||||||
|
final_cover_url = storybook.cover_url
|
||||||
|
if request.generate_images:
|
||||||
|
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
# Cover Task
|
||||||
|
async def _gen_cover():
|
||||||
|
if storybook.cover_prompt and not storybook.cover_url:
|
||||||
|
try:
|
||||||
|
return await generate_image(storybook.cover_prompt, db=db)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("cover_gen_failed", error=str(e))
|
||||||
|
return storybook.cover_url
|
||||||
|
tasks.append(_gen_cover())
|
||||||
|
|
||||||
|
# Page Tasks
|
||||||
|
async def _gen_page(page):
|
||||||
|
if page.image_prompt and not page.image_url:
|
||||||
|
try:
|
||||||
|
url = await generate_image(page.image_prompt, db=db)
|
||||||
|
page.image_url = url
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
|
||||||
|
|
||||||
|
for page in storybook.pages:
|
||||||
|
tasks.append(_gen_page(page))
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Update cover result
|
||||||
|
cover_res = results[0]
|
||||||
|
if isinstance(cover_res, str):
|
||||||
|
final_cover_url = cover_res
|
||||||
|
|
||||||
|
logger.info("storybook_parallel_generation_complete")
|
||||||
|
|
||||||
|
# 5. Save to DB
|
||||||
|
pages_data = [
|
||||||
|
{
|
||||||
|
"page_number": p.page_number,
|
||||||
|
"text": p.text,
|
||||||
|
"image_prompt": p.image_prompt,
|
||||||
|
"image_url": p.image_url,
|
||||||
|
}
|
||||||
|
for p in storybook.pages
|
||||||
|
]
|
||||||
|
|
||||||
|
story = Story(
|
||||||
|
user_id=user_id,
|
||||||
|
child_profile_id=profile_id,
|
||||||
|
universe_id=universe_id,
|
||||||
|
title=storybook.title,
|
||||||
|
mode="storybook",
|
||||||
|
pages=pages_data,
|
||||||
|
story_text=None,
|
||||||
|
cover_prompt=storybook.cover_prompt,
|
||||||
|
image_url=final_cover_url,
|
||||||
|
)
|
||||||
|
db.add(story)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(story)
|
||||||
|
|
||||||
|
if universe_id:
|
||||||
|
extract_story_achievements.delay(story.id, universe_id)
|
||||||
|
|
||||||
|
# 6. Build Response
|
||||||
|
response_pages = [
|
||||||
|
StorybookPageResponse(
|
||||||
|
page_number=p["page_number"],
|
||||||
|
text=p["text"],
|
||||||
|
image_prompt=p["image_prompt"],
|
||||||
|
image_url=p.get("image_url"),
|
||||||
|
)
|
||||||
|
for p in pages_data
|
||||||
|
]
|
||||||
|
|
||||||
|
return StorybookResponse(
|
||||||
|
id=story.id,
|
||||||
|
title=storybook.title,
|
||||||
|
main_character=storybook.main_character,
|
||||||
|
art_style=storybook.art_style,
|
||||||
|
pages=response_pages,
|
||||||
|
cover_prompt=storybook.cover_prompt,
|
||||||
|
cover_url=final_cover_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Missing Endpoints Logic (for Issue #5) ====================
|
||||||
|
|
||||||
|
async def list_stories(
|
||||||
|
user_id: str,
|
||||||
|
limit: int,
|
||||||
|
offset: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> list[Story]:
|
||||||
|
"""List stories for user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Story)
|
||||||
|
.where(Story.user_id == user_id)
|
||||||
|
.order_by(desc(Story.created_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_story_detail(
|
||||||
|
story_id: int,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> Story:
|
||||||
|
"""Get story detail."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Story).where(Story.id == story_id, Story.user_id == user_id)
|
||||||
|
)
|
||||||
|
story = result.scalar_one_or_none()
|
||||||
|
if not story:
|
||||||
|
raise HTTPException(status_code=404, detail="Story not found")
|
||||||
|
return story
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_story(
|
||||||
|
story_id: int,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Delete a story."""
|
||||||
|
story = await get_story_detail(story_id, user_id, db)
|
||||||
|
await db.delete(story)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_story_from_result(
|
||||||
|
result, # StoryOutput
|
||||||
|
user_id: str,
|
||||||
|
profile_id: str | None,
|
||||||
|
universe_id: str | None,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> Story:
|
||||||
|
"""Save a generated story to DB (helper for stream endpoint)."""
|
||||||
|
story = Story(
|
||||||
|
user_id=user_id,
|
||||||
|
child_profile_id=profile_id,
|
||||||
|
universe_id=universe_id,
|
||||||
|
title=result.title,
|
||||||
|
story_text=result.story_text,
|
||||||
|
cover_prompt=result.cover_prompt_suggestion,
|
||||||
|
mode=result.mode,
|
||||||
|
)
|
||||||
|
db.add(story)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(story)
|
||||||
|
|
||||||
|
if universe_id:
|
||||||
|
extract_story_achievements.delay(story.id, universe_id)
|
||||||
|
|
||||||
|
return story
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_story_cover(
|
||||||
|
story_id: int,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> str:
|
||||||
|
"""Generate cover image for an existing story."""
|
||||||
|
story = await get_story_detail(story_id, user_id, db)
|
||||||
|
|
||||||
|
if not story.cover_prompt:
|
||||||
|
raise HTTPException(status_code=400, detail="Story has no cover prompt")
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_url = await generate_image(story.cover_prompt, db=db)
|
||||||
|
story.image_url = image_url
|
||||||
|
await db.commit()
|
||||||
|
return image_url
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
|
||||||
|
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_story_audio(
|
||||||
|
story_id: int,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> bytes:
|
||||||
|
"""Generate audio for a story."""
|
||||||
|
story = await get_story_detail(story_id, user_id, db)
|
||||||
|
|
||||||
|
if not story.story_text:
|
||||||
|
raise HTTPException(status_code=400, detail="Story has no text")
|
||||||
|
|
||||||
|
# TODO: Check if audio is already cached/saved?
|
||||||
|
# For now, generate on the fly via provider
|
||||||
|
from app.services.provider_router import text_to_speech
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio_data = await text_to_speech(story.story_text, db=db)
|
||||||
|
return audio_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
|
||||||
|
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_story_achievements(
|
||||||
|
story_id: int,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> list[AchievementItem]:
|
||||||
|
"""Get achievements unlocked by a specific story."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Story)
|
||||||
|
.options(joinedload(Story.story_universe))
|
||||||
|
.where(Story.id == story_id, Story.user_id == user_id)
|
||||||
|
)
|
||||||
|
story = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not story:
|
||||||
|
raise HTTPException(status_code=404, detail="Story not found")
|
||||||
|
|
||||||
|
if not story.universe_id or not story.story_universe:
|
||||||
|
return []
|
||||||
|
|
||||||
|
universe = story.story_universe
|
||||||
|
if not universe.achievements:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ach in universe.achievements:
|
||||||
|
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
|
||||||
|
results.append(AchievementItem(
|
||||||
|
type=ach.get("type", "Unknown"),
|
||||||
|
description=ach.get("description", ""),
|
||||||
|
obtained_at=ach.get("obtained_at")
|
||||||
|
))
|
||||||
|
return results
|
||||||
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
# 梦语织机 (DreamWeaver) 记忆系统升级 PRD
|
|
||||||
> 版本: v1.0 | 状态: 规划中 | 优先级: High
|
|
||||||
|
|
||||||
## 1. 核心愿景 (Vision)
|
|
||||||
|
|
||||||
将当前的"数据存储"升级为有温度的**"情感连接系统"**。
|
|
||||||
我们不只是在记住数据,而是在**维护孩子与故事世界的关系**。让每一个故事不再是孤立的碎片,而是构建孩子专属"故事宇宙"的砖瓦。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. 产品痛点与解决方案
|
|
||||||
|
|
||||||
| 用户角色 | 核心痛点 | 解决方案 | 预期价值 |
|
|
||||||
|---------|---------|---------|---------|
|
|
||||||
| **孩子** | "上次的小兔子怎么不认识我了?" <br> 故事之间缺乏连续性,只有单次体验。 | **角色一致性与记忆注入** <br> 故事开头主动提及往事,角色性格延续。 | 建立情感依恋,提升沉浸感。 |
|
|
||||||
| **家长** | "这App除了生成故事还能干嘛?" <br> 无法感知产品的长期教育价值。 | **显性化成长轨迹** <br> 词汇量统计、主题变化、成就徽章可视化。 | 提高付费意愿,提供社交货币。 |
|
|
||||||
| **平台** | 用户用完即走,缺乏留存壁垒。 | **沉没成本与情感资产** <br> 积累的记忆越多,越舍不得离开。 | 提升长期留存率 (LTV)。 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. 功能架构:记忆分层模型
|
|
||||||
|
|
||||||
### 3.1 层级 1: 核心档案 (Identity Layer)
|
|
||||||
*性质:永久、静态、显性*
|
|
||||||
* **数据**: 姓名、年龄、性别。
|
|
||||||
* **输入**: 家长在 Onboarding 阶段手动输入。
|
|
||||||
* **作用**: 决定故事的基础适龄性和称呼。
|
|
||||||
|
|
||||||
### 3.2 层级 2: 故事宇宙 (Universe Layer)
|
|
||||||
*性质:长期、动态积累、半显性*
|
|
||||||
* **主角设定**: 姓名、性格特征(勇敢/害羞)、外貌特征(戴眼镜/卷发)。
|
|
||||||
* **常驻配角**: 从随机故事中涌现出的固定伙伴(如"爱吃胡萝卜的松鼠奇奇")。
|
|
||||||
* **世界观**: 故事发生的背景(魔法森林、未来城市、海底世界)。
|
|
||||||
* **成就系统**: 孩子获得的虚拟奖励(勇气勋章、小小探险家)。
|
|
||||||
|
|
||||||
### 3.3 层级 3: 工作记忆 (Working Memory)
|
|
||||||
*性质:短期、自动衰减、隐性*
|
|
||||||
* **关键情节**: 最近 3 个故事的结局和核心冲突。
|
|
||||||
* **情感标记**: 孩子对特定内容的反应(根据“重播”、“跳过”推断)。
|
|
||||||
* **新学词汇**: 故事中出现的高级词汇。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. 关键功能特性 (Feature Specs)
|
|
||||||
|
|
||||||
### 4.1 智能开场白 (Memory Injection)
|
|
||||||
在生成新故事时,Prompt 必须包含一段"记忆唤醒"指令。
|
|
||||||
* **示例**: "小明,还记得上周我们帮小松鼠找回了松果吗?今天,小松鼠带来了一位新朋友..."
|
|
||||||
* **策略**: 提取权重最高的 Top 3 记忆注入 Prompt。
|
|
||||||
|
|
||||||
### 4.2 成长时间轴 (Growth Timeline)
|
|
||||||
一个可视化的 H5 页面或 App 模块,以时间轴形式展示里程碑。
|
|
||||||
* **节点类型**:
|
|
||||||
* 🌟 **初次相遇**: 创建角色的第一天。
|
|
||||||
* 📖 **阅读打卡**: 累计阅读 10/50/100 本。
|
|
||||||
* 🏅 **获得成就**: 获得"诚实勋章"。
|
|
||||||
* 🧠 **能力解锁**: 第一次阅读"科幻"题材。
|
|
||||||
|
|
||||||
### 4.3 成就仪式感 (Achievement Ceremony)
|
|
||||||
* **触发**: 故事生成并分析后,如果获得新成就。
|
|
||||||
* **表现**: 弹窗动画 + 音效 + "恭喜获得 [勇气] 徽章"。
|
|
||||||
* **分享**:允许生成带二维码的成就海报。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. 记忆类型扩展 (Memory Types)
|
|
||||||
|
|
||||||
| 类型 Key | 描述 | 来源 | 过期策略 |
|
|
||||||
|---------|------|------|---------|
|
|
||||||
| `recent_story` | 最近读过的故事梗概 | 阅读事件 | 30天衰减 |
|
|
||||||
| `favorite_character` | 孩子喜欢的角色 | 重播/高评分 | 长期有效 |
|
|
||||||
| `scary_element` | 孩子害怕/不喜欢的元素 | 跳过/负反馈 | 长期有效 (避雷) |
|
|
||||||
| `vocabulary_growth` | 新掌握的词汇 | 故事分析 | 90天衰减 |
|
|
||||||
| `emotional_highlight` | 高光时刻 (如: 特别开心的情节) | 互动数据 | 60天衰减 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. 实施路线图 (Roadmap)
|
|
||||||
|
|
||||||
### Phase 1: 基础建设 (v0.3.0)
|
|
||||||
* [x] 数据库 `MemoryItem` 表 (已存在)。
|
|
||||||
* [ ] 扩展 `MemoryItem` 类型字段,支持更多维度。
|
|
||||||
* [ ] 优化 `_build_memory_context`,支持更自然的 Prompt 注入。
|
|
||||||
* [ ] 前端:简单的"近期回忆"展示列表。
|
|
||||||
|
|
||||||
### Phase 2: 可视化与成就 (v0.4.0)
|
|
||||||
* [ ] 实现"成就提取器" (Achievement Extractor) 的闭环通知。
|
|
||||||
* [ ] 前端:开发"我的成就"和"成长时间轴"页面。
|
|
||||||
* [ ] 增加故事开场白的动态生成逻辑。
|
|
||||||
|
|
||||||
### Phase 3: 深度智能 (v0.5.0+)
|
|
||||||
* [ ] 引入向量数据库,实现基于语义的记忆检索 (不仅是时间最近)。
|
|
||||||
* [ ] 情感分析模型:分析用户行为推断情感倾向。
|
|
||||||
98
backend/docs/refactoring_plan.md
Normal file
98
backend/docs/refactoring_plan.md
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# DreamWeaver 重构实施计划
|
||||||
|
|
||||||
|
## 1. 概述
|
||||||
|
|
||||||
|
本文档基于对当前架构的深入分析,制定了从稳定性、可维护性到可扩展性的分阶段重构计划。
|
||||||
|
|
||||||
|
**目标**:
|
||||||
|
- **短期**:解决单点故障风险,优化开发体验,清理关键技术债。
|
||||||
|
- **中期**:提升系统高可用能力,增强监控与可观测性。
|
||||||
|
- **长期**:架构演进,支持大规模并发与复杂业务场景。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 短期优化计划 (1-2周)
|
||||||
|
|
||||||
|
**重点**:消除即时风险,提升部署效率。
|
||||||
|
|
||||||
|
### 2.1 统一镜像构建 (High Priority)
|
||||||
|
目前 `backend`, `backend-admin`, `worker`, `celery-beat` 重复构建 4 次,浪费资源且镜像版本可能不一致。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 修改 `backend/Dockerfile` 为通用基础镜像。
|
||||||
|
- [ ] 更新 `docker-compose.yml`,定义 `backend-base` 服务或使用 `image` 标签共享镜像。
|
||||||
|
- [ ] 确保所有 Python 服务共用同一构建产物,仅启动命令不同。
|
||||||
|
|
||||||
|
### 2.2 修复 Provider 缓存与限流 (High Priority)
|
||||||
|
内存缓存 (`TTLCache`, `_latency_cache`) 在多进程/多实例下失效。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 引入 Redis 作为共享缓存后端。
|
||||||
|
- [ ] 重构 `_load_provider_cache`,将 Provider 配置缓存至 Redis。
|
||||||
|
- [ ] 重构 `stories.py` 中的限流逻辑,使用 `redis-cell` 或简单的 Redis 计数器替代 `TTLCache`。
|
||||||
|
|
||||||
|
### 2.3 拆分 `stories.py` (Medium Priority)
|
||||||
|
`app/api/stories.py` 超过 600 行,包含 API 定义、业务逻辑、验证逻辑,维护困难。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 创建 `app/services/story_service.py`,迁移生成、润色、PDF生成等核心逻辑。
|
||||||
|
- [ ] 创建 `app/schemas/story_schema.py`,迁移 Pydantic 模型(`GenerateRequest`, `StoryResponse` 等)。
|
||||||
|
- [ ] API 层 `stories.py` 仅保留路由定义和依赖注入,调用 Service 层。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 中期优化计划 (1-2月)
|
||||||
|
|
||||||
|
**重点**:高可用 (HA) 与系统韧性。
|
||||||
|
|
||||||
|
### 3.1 数据库高可用 (Critical)
|
||||||
|
当前 PostgreSQL 为单点,且 Admin/User 混合使用。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 部署 PostgreSQL 主从复制 (Master-Slave)。
|
||||||
|
- [ ] 配置 `PgBouncer` 或 SQLAlchemy 读写分离,减轻主库压力。
|
||||||
|
- [ ] 实施数据库自动备份策略 (如 `pg_dump` 定时上传 S3)。
|
||||||
|
|
||||||
|
### 3.2 消息队列高可用 (Critical)
|
||||||
|
Redis 单点故障将导致 Celery 任务全盘停摆。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 迁移至 Redis Sentinel 或 Redis Cluster 模式。
|
||||||
|
- [ ] 更新 Celery 配置以支持 Sentinel/Cluster 连接串。
|
||||||
|
|
||||||
|
### 3.3 增强可观测性 (Important)
|
||||||
|
目前仅有简单的日志,缺乏系统级指标。
|
||||||
|
|
||||||
|
- **Action Items**:
|
||||||
|
- [ ] 集成 Prometheus Client,暴露 `/metrics` 端点。
|
||||||
|
- [ ] 部署 Grafana + Prometheus,监控 API 延迟、QPS、Celery 队列积压情况。
|
||||||
|
- [ ] 完善 `ProviderMetrics`,增加可视化大盘,实时监控 AI 供应商的成本与成功率。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 长期架构演进 (季度规划)
|
||||||
|
|
||||||
|
**重点**:业务解耦与规模化。
|
||||||
|
|
||||||
|
### 4.1 统一 API 网关
|
||||||
|
- **当前**:前端直连后端端口,CORS 配置分散。
|
||||||
|
- **演进**:引入 Traefik 或 Nginx 作为统一网关,管理路由、SSL、全局限流、统一鉴权。
|
||||||
|
|
||||||
|
### 4.2 前端工程合并
|
||||||
|
- **当前**:User App 和 Admin Console 是完全独立的两个项目,但在组件和工具链上高度重复。
|
||||||
|
- **演进**:使用一种 Monorepo 策略或基于路由的单一应用策略,共享组件库和类型定义,减少维护成本。
|
||||||
|
|
||||||
|
### 4.3 事件驱动架构完善
|
||||||
|
- **当前**:部分业务逻辑耦合在 API 中。
|
||||||
|
- **演进**:扩展事件总线,将“阅读记录”、“成就解锁”、“通知推送”等非核心链路完全异步化,通过 Domain Events 解耦。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 实施路线图
|
||||||
|
|
||||||
|
| 阶段 | 时间估算 | 关键里程碑 |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| **Phase 1: 基础夯实** | Week 1-2 | Docker 构建优化上线,Redis 替代内存缓存。 |
|
||||||
|
| **Phase 2: 代码重构** | Week 3-4 | `stories.py` 拆分完成,Service 层建立。 |
|
||||||
|
| **Phase 3: 高可用建设** | Month 2 | 数据库与 Redis 实现主备/集群模式。 |
|
||||||
|
| **Phase 4: 监控体系** | Month 2 | Grafana 监控大盘上线,关键指标报警配置完毕。 |
|
||||||
@@ -121,7 +121,7 @@ def mock_text_provider():
|
|||||||
cover_prompt_suggestion="A cute rabbit",
|
cover_prompt_suggestion="A cute rabbit",
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock:
|
with patch("app.services.story_service.generate_story_content", new_callable=AsyncMock) as mock:
|
||||||
mock.return_value = mock_result
|
mock.return_value = mock_result
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def mock_text_provider():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_image_provider():
|
def mock_image_provider():
|
||||||
"""Mock 图像生成。"""
|
"""Mock 图像生成。"""
|
||||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock:
|
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
|
||||||
mock.return_value = "https://example.com/image.png"
|
mock.return_value = "https://example.com/image.png"
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
@@ -137,7 +137,7 @@ def mock_image_provider():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_tts_provider():
|
def mock_tts_provider():
|
||||||
"""Mock TTS。"""
|
"""Mock TTS。"""
|
||||||
with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock:
|
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
|
||||||
mock.return_value = b"fake-audio-bytes"
|
mock.return_value = b"fake-audio-bytes"
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ class TestStoryGenerate:
|
|||||||
assert data["mode"] == "generated"
|
assert data["mode"] == "generated"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="GET /api/stories (列表) 端点尚未实现")
|
|
||||||
class TestStoryList:
|
class TestStoryList:
|
||||||
"""故事列表测试。"""
|
"""故事列表测试。"""
|
||||||
|
|
||||||
@@ -94,7 +93,6 @@ class TestStoryList:
|
|||||||
assert len(data) == 0
|
assert len(data) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="GET /api/stories/{id} (详情) 端点尚未实现")
|
|
||||||
class TestStoryDetail:
|
class TestStoryDetail:
|
||||||
"""故事详情测试。"""
|
"""故事详情测试。"""
|
||||||
|
|
||||||
@@ -118,7 +116,6 @@ class TestStoryDetail:
|
|||||||
assert data["story_text"] == test_story.story_text
|
assert data["story_text"] == test_story.story_text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="DELETE /api/stories/{id} (删除) 端点尚未实现")
|
|
||||||
class TestStoryDelete:
|
class TestStoryDelete:
|
||||||
"""故事删除测试。"""
|
"""故事删除测试。"""
|
||||||
|
|
||||||
@@ -168,7 +165,6 @@ class TestRateLimit:
|
|||||||
assert "Too many requests" in response.json()["detail"]
|
assert "Too many requests" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
|
|
||||||
class TestImageGenerate:
|
class TestImageGenerate:
|
||||||
"""封面图片生成测试。"""
|
"""封面图片生成测试。"""
|
||||||
|
|
||||||
@@ -183,7 +179,6 @@ class TestImageGenerate:
|
|||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="GET /api/audio/{id} 端点尚未实现")
|
|
||||||
class TestAudio:
|
class TestAudio:
|
||||||
"""语音朗读测试。"""
|
"""语音朗读测试。"""
|
||||||
|
|
||||||
@@ -233,7 +228,7 @@ class TestGenerateFull:
|
|||||||
|
|
||||||
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
||||||
"""图片生成失败时返回部分成功。"""
|
"""图片生成失败时返回部分成功。"""
|
||||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
|
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
|
||||||
mock_img.side_effect = Exception("Image API error")
|
mock_img.side_effect = Exception("Image API error")
|
||||||
response = auth_client.post(
|
response = auth_client.post(
|
||||||
"/api/stories/generate/full",
|
"/api/stories/generate/full",
|
||||||
@@ -261,7 +256,6 @@ class TestGenerateFull:
|
|||||||
assert call_kwargs["education_theme"] == "勇气与友谊"
|
assert call_kwargs["education_theme"] == "勇气与友谊"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="POST /api/image/generate/{id} 端点尚未实现")
|
|
||||||
class TestImageGenerateSuccess:
|
class TestImageGenerateSuccess:
|
||||||
"""封面图片生成成功测试。"""
|
"""封面图片生成成功测试。"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user