Refactor Phase 2: Split stories.py into Schema/Service/Controller, add missing endpoints, fix async bug

This commit is contained in:
zhangtuo
2026-02-10 17:14:54 +08:00
parent c351d16d3e
commit 9cdff18336
13 changed files with 881 additions and 612 deletions

View File

@@ -2,136 +2,41 @@
import asyncio
import json
import time
import uuid
from typing import AsyncGenerator, Literal
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import require_user
from app.core.logging import get_logger
from app.db.database import get_db
from app.db.models import ChildProfile, Story, StoryUniverse, User
from app.services.provider_router import (
generate_image,
generate_story_content,
generate_storybook,
text_to_speech,
)
from app.core.rate_limiter import check_rate_limit
from app.tasks.achievements import extract_story_achievements
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
GenerateRequest,
StoryResponse,
FullStoryResponse,
StorybookRequest,
StorybookResponse,
StoryListItem,
AchievementItem,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
)
logger = get_logger(__name__)
router = APIRouter()
MAX_DATA_LENGTH = 2000
MAX_EDU_THEME_LENGTH = 200
MAX_TTS_LENGTH = 4000
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
class GenerateRequest(BaseModel):
"""Story generation request."""
type: Literal["keywords", "full_story"]
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
child_profile_id: str | None = None
universe_id: str | None = None
class StoryResponse(BaseModel):
"""Story response."""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
mode: str
child_profile_id: str | None = None
universe_id: str | None = None
class StoryListItem(BaseModel):
"""Story list item."""
id: int
title: str
image_url: str | None
created_at: str
mode: str
class FullStoryResponse(BaseModel):
"""完整故事响应(含图片和音频状态)。"""
id: int
title: str
story_text: str
cover_prompt: str | None
image_url: str | None
audio_ready: bool
mode: str
errors: dict[str, str | None] = Field(default_factory=dict)
child_profile_id: str | None = None
universe_id: str | None = None
from app.services.memory_service import build_enhanced_memory_context
async def _validate_profile_and_universe(
request: GenerateRequest,
user: User,
db: AsyncSession,
) -> tuple[str | None, str | None]:
if not request.child_profile_id and not request.universe_id:
return None, None
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
profile = result.scalar_one_or_none()
if not profile:
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
return profile_id, universe_id
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
@@ -140,47 +45,7 @@ async def generate_story(
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return StoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=story.image_url,
mode=story.mode,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
@@ -189,71 +54,9 @@ async def generate_story_full(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成完整故事(故事 + 并行生成图片和音频)。
部分成功策略:故事必须成功,图片/音频失败不影响整体。
"""
"""Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# Step 1: 故事生成(必须成功)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
)
except Exception as exc:
logger.error("story_generation_failed", error=str(exc))
raise HTTPException(status_code=502, detail="Story generation failed, please try again.")
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# Step 2: 生成封面图片(音频按需生成,避免浪费)
errors: dict[str, str | None] = {}
image_url: str | None = None
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
story.image_url = image_url
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
# 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取
# 这样避免生成后丢弃造成的成本浪费
return FullStoryResponse(
id=story.id,
title=story.title,
story_text=story.story_text,
cover_prompt=story.cover_prompt,
image_url=image_url,
audio_ready=False, # 音频需要用户主动请求
mode=story.mode,
errors=errors,
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
)
return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
@@ -263,53 +66,39 @@ async def generate_story_stream(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""流式生成故事SSE
事件流程:
- started: 返回 story_id
- story_ready: 返回 title, content
- story_failed: 返回 error
- image_ready: 返回 image_url
- image_failed: 返回 error
- complete: 结束流
"""
"""流式生成故事SSE"""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
profile_id, universe_id = await _validate_profile_and_universe(request, user, db)
# Validation
profile_id, universe_id = await story_service.validate_profile_and_universe(
request.child_profile_id, request.universe_id, user.id, db
)
# Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
async def event_generator() -> AsyncGenerator[dict, None]:
story_id = str(uuid.uuid4())
yield {"event": "started", "data": json.dumps({"story_id": story_id})}
# Step 1: 生成故事
# Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
return
# 保存故事
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
# Save Story
story = await story_service.create_story_from_result(
result, user.id, profile_id, universe_id, db
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
yield {
"event": "story_ready",
@@ -324,10 +113,11 @@ async def generate_story_stream(
}),
}
# Step 2: 并行生成图片(音频按需)
# Step 2: Generate Image
if story.cover_prompt:
try:
image_url = await generate_image(story.cover_prompt)
# Direct call to provider router's generate_image, sharing db session
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
await db.commit()
yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})}
@@ -340,217 +130,71 @@ async def generate_story_stream(
return EventSourceResponse(event_generator())
# ==================== Storybook API ====================
class StorybookRequest(BaseModel):
"""Storybook 生成请求。"""
keywords: str = Field(..., min_length=1, max_length=200)
page_count: int = Field(default=6, ge=4, le=12)
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
generate_images: bool = Field(default=False, description="是否同时生成插图")
child_profile_id: str | None = None
universe_id: str | None = None
class StorybookPageResponse(BaseModel):
"""故事书单页响应。"""
page_number: int
text: str
image_prompt: str
image_url: str | None = None
class StorybookResponse(BaseModel):
"""故事书响应。"""
id: int | None = None
title: str
main_character: str
art_style: str
pages: list[StorybookPageResponse]
cover_prompt: str
cover_url: str | None = None
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""生成分页故事书并保存。
返回故事书结构,包含每页文字和图像提示词。
"""
"""Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
# 验证档案和宇宙
# 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数
# 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。
# 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好)
profile_id = request.child_profile_id
universe_id = request.universe_id
if profile_id:
result = await db.execute(
select(ChildProfile).where(
ChildProfile.id == profile_id,
ChildProfile.user_id == user.id,
)
)
if not result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="孩子档案不存在")
if universe_id:
result = await db.execute(
select(StoryUniverse)
.join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
.where(
StoryUniverse.id == universe_id,
ChildProfile.user_id == user.id,
)
)
universe = result.scalar_one_or_none()
if not universe:
raise HTTPException(status_code=404, detail="故事宇宙不存在")
if profile_id and universe.child_profile_id != profile_id:
raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
if not profile_id:
profile_id = universe.child_profile_id
logger.info(
"storybook_request",
user_id=user.id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
try:
# 注意generate_storybook 目前可能不支持记忆上下文注入
# 我们需要看看 generate_storybook 的签名
# 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# ==============================================================================
# 核心升级: 并行全量生成 (Parallel Full Rendering)
# ==============================================================================
final_cover_url = storybook.cover_url
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
# 1. 准备所有生图任务 (封面 + 所有内页)
tasks = []
# 封面任务
async def _gen_cover():
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as e:
logger.warning("cover_gen_failed", error=str(e))
return storybook.cover_url
tasks.append(_gen_cover())
# 内页任务
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
url = await generate_image(page.image_prompt, db=db)
page.image_url = url
except Exception as e:
logger.warning("page_gen_failed", page=page.page_number, error=str(e))
for page in storybook.pages:
tasks.append(_gen_page(page))
# 2. 并发执行所有任务
# 使用 return_exceptions=True 防止单张失败影响整体
results = await asyncio.gather(*tasks, return_exceptions=True)
# 3. 更新封面结果 (results[0] 是封面任务的返回值)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# ==============================================================================
# 构建并保存 Story 对象
# 将 pages 对象转换为字典列表以存入 JSON 字段
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
user_id=user.id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data, # 存入 JSON 字段
story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 构建响应 (使用更新后的 pages_data)
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
return StorybookResponse(
id=story.id,
title=storybook.title,
main_character=storybook.main_character,
art_style=storybook.art_style,
pages=response_pages,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
)
return await story_service.generate_storybook_service(request, user.id, db)
class AchievementItem(BaseModel):
type: str
description: str
obtained_at: str | None = None
# ==================== Missing Endpoints (Issue #5) ====================
@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
limit: int = 20,
offset: int = 0,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List stories."""
return await story_service.list_stories(user.id, limit, offset, db)
@router.get("/stories/{story_id}", response_model=StoryResponse)
async def get_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story detail."""
return await story_service.get_story_detail(story_id, user.id, db)
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Delete story."""
await story_service.delete_story(story_id, user.id, db)
return {"message": "Deleted"}
@router.post("/image/generate/{story_id}")
async def generate_story_image(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate cover image for story."""
url = await story_service.generate_story_cover(story_id, user.id, db)
return {"image_url": url}
@router.get("/audio/{story_id}")
async def get_story_audio(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get story audio (MP3)."""
audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
@@ -559,32 +203,5 @@ async def get_story_achievements(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get achievements unlocked by a specific story."""
# 使用 joinedload 避免 N+1 查询
result = await db.execute(
select(Story)
.options(joinedload(Story.story_universe))
.where(Story.id == story_id, Story.user_id == user.id)
)
story = result.scalar_one_or_none()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
if not story.universe_id or not story.story_universe:
return []
universe = story.story_universe
if not universe.achievements:
return []
results = []
for ach in universe.achievements:
if isinstance(ach, dict) and ach.get("source_story_id") == story_id:
results.append(AchievementItem(
type=ach.get("type", "Unknown"),
description=ach.get("description", ""),
obtained_at=ach.get("obtained_at")
))
return results
"""Get story achievements."""
return await story_service.get_story_achievements(story_id, user.id, db)