"""Story related APIs.""" import asyncio import json import time import uuid from typing import AsyncGenerator, Literal from cachetools import TTLCache from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import Response from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from sse_starlette.sse import EventSourceResponse from app.core.deps import require_user from app.core.logging import get_logger from app.db.database import get_db from app.db.models import ChildProfile, Story, StoryUniverse, User from app.services.provider_router import ( generate_image, generate_story_content, generate_storybook, text_to_speech, ) from app.tasks.achievements import extract_story_achievements logger = get_logger(__name__) router = APIRouter() MAX_DATA_LENGTH = 2000 MAX_EDU_THEME_LENGTH = 200 MAX_TTS_LENGTH = 4000 RATE_LIMIT_WINDOW = 60 # seconds RATE_LIMIT_REQUESTS = 10 RATE_LIMIT_CACHE_SIZE = 10000 # 最大跟踪用户数 _request_log: TTLCache[str, list[float]] = TTLCache( maxsize=RATE_LIMIT_CACHE_SIZE, ttl=RATE_LIMIT_WINDOW * 2 ) def _check_rate_limit(user_id: str): now = time.time() timestamps = _request_log.get(user_id, []) timestamps = [t for t in timestamps if now - t <= RATE_LIMIT_WINDOW] if len(timestamps) >= RATE_LIMIT_REQUESTS: raise HTTPException(status_code=429, detail="Too many requests, please slow down.") timestamps.append(now) _request_log[user_id] = timestamps class GenerateRequest(BaseModel): """Story generation request.""" type: Literal["keywords", "full_story"] data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH) education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) child_profile_id: str | None = None universe_id: str | None = None class StoryResponse(BaseModel): """Story response.""" id: int title: str story_text: str cover_prompt: str | None image_url: str | None mode: str child_profile_id: str | None = None universe_id: str | None = None class StoryListItem(BaseModel): """Story list item.""" id: int title: str image_url: str | None created_at: str mode: str class FullStoryResponse(BaseModel): """完整故事响应(含图片和音频状态)。""" id: int title: str story_text: str cover_prompt: str | None image_url: str | None audio_ready: bool mode: str errors: dict[str, str | None] = Field(default_factory=dict) child_profile_id: str | None = None universe_id: str | None = None from app.services.memory_service import build_enhanced_memory_context async def _validate_profile_and_universe( request: GenerateRequest, user: User, db: AsyncSession, ) -> tuple[str | None, str | None]: if not request.child_profile_id and not request.universe_id: return None, None profile_id = request.child_profile_id universe_id = request.universe_id if profile_id: result = await db.execute( select(ChildProfile).where( ChildProfile.id == profile_id, ChildProfile.user_id == user.id, ) ) profile = result.scalar_one_or_none() if not profile: raise HTTPException(status_code=404, detail="孩子档案不存在") if universe_id: result = await db.execute( select(StoryUniverse) .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) .where( StoryUniverse.id == universe_id, ChildProfile.user_id == user.id, ) ) universe = result.scalar_one_or_none() if not universe: raise HTTPException(status_code=404, detail="故事宇宙不存在") if profile_id and universe.child_profile_id != profile_id: raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") if not profile_id: profile_id = universe.child_profile_id return profile_id, universe_id @router.post("/stories/generate", response_model=StoryResponse) async def generate_story( request: GenerateRequest, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): """Generate or enhance a story.""" _check_rate_limit(user.id) profile_id, universe_id = await _validate_profile_and_universe(request, user, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) try: result = await generate_story_content( input_type=request.type, data=request.data, education_theme=request.education_theme, memory_context=memory_context, ) except HTTPException: raise except Exception: raise HTTPException(status_code=502, detail="Story generation failed, please try again.") story = Story( user_id=user.id, child_profile_id=profile_id, universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) return StoryResponse( id=story.id, title=story.title, story_text=story.story_text, cover_prompt=story.cover_prompt, image_url=story.image_url, mode=story.mode, child_profile_id=story.child_profile_id, universe_id=story.universe_id, ) @router.post("/stories/generate/full", response_model=FullStoryResponse) async def generate_story_full( request: GenerateRequest, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): """生成完整故事(故事 + 并行生成图片和音频)。 部分成功策略:故事必须成功,图片/音频失败不影响整体。 """ _check_rate_limit(user.id) profile_id, universe_id = await _validate_profile_and_universe(request, user, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) # Step 1: 故事生成(必须成功) try: result = await generate_story_content( input_type=request.type, data=request.data, education_theme=request.education_theme, memory_context=memory_context, ) except Exception as exc: logger.error("story_generation_failed", error=str(exc)) raise HTTPException(status_code=502, detail="Story generation failed, please try again.") # 保存故事 story = Story( user_id=user.id, child_profile_id=profile_id, universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) # Step 2: 生成封面图片(音频按需生成,避免浪费) errors: dict[str, str | None] = {} image_url: str | None = None if story.cover_prompt: try: image_url = await generate_image(story.cover_prompt) story.image_url = image_url await db.commit() except Exception as exc: errors["image"] = str(exc) logger.warning("image_generation_failed", story_id=story.id, error=str(exc)) # 注意:音频不在此处预生成,用户通过 /api/audio/{id} 按需获取 # 这样避免生成后丢弃造成的成本浪费 return FullStoryResponse( id=story.id, title=story.title, story_text=story.story_text, cover_prompt=story.cover_prompt, image_url=image_url, audio_ready=False, # 音频需要用户主动请求 mode=story.mode, errors=errors, child_profile_id=story.child_profile_id, universe_id=story.universe_id, ) @router.post("/stories/generate/stream") async def generate_story_stream( request: GenerateRequest, req: Request, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): """流式生成故事(SSE)。 事件流程: - started: 返回 story_id - story_ready: 返回 title, content - story_failed: 返回 error - image_ready: 返回 image_url - image_failed: 返回 error - complete: 结束流 """ _check_rate_limit(user.id) profile_id, universe_id = await _validate_profile_and_universe(request, user, db) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) async def event_generator() -> AsyncGenerator[dict, None]: story_id = str(uuid.uuid4()) yield {"event": "started", "data": json.dumps({"story_id": story_id})} # Step 1: 生成故事 try: result = await generate_story_content( input_type=request.type, data=request.data, education_theme=request.education_theme, memory_context=memory_context, ) except Exception as e: logger.error("sse_story_generation_failed", error=str(e)) yield {"event": "story_failed", "data": json.dumps({"error": str(e)})} return # 保存故事 story = Story( user_id=user.id, child_profile_id=profile_id, universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) yield { "event": "story_ready", "data": json.dumps({ "id": story.id, "title": story.title, "content": story.story_text, "cover_prompt": story.cover_prompt, "mode": story.mode, "child_profile_id": story.child_profile_id, "universe_id": story.universe_id, }), } # Step 2: 并行生成图片(音频按需) if story.cover_prompt: try: image_url = await generate_image(story.cover_prompt) story.image_url = image_url await db.commit() yield {"event": "image_ready", "data": json.dumps({"image_url": image_url})} except Exception as e: logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e)) yield {"event": "image_failed", "data": json.dumps({"error": str(e)})} yield {"event": "complete", "data": json.dumps({"story_id": story.id})} return EventSourceResponse(event_generator()) # ==================== Storybook API ==================== class StorybookRequest(BaseModel): """Storybook 生成请求。""" keywords: str = Field(..., min_length=1, max_length=200) page_count: int = Field(default=6, ge=4, le=12) education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH) generate_images: bool = Field(default=False, description="是否同时生成插图") child_profile_id: str | None = None universe_id: str | None = None class StorybookPageResponse(BaseModel): """故事书单页响应。""" page_number: int text: str image_prompt: str image_url: str | None = None class StorybookResponse(BaseModel): """故事书响应。""" id: int | None = None title: str main_character: str art_style: str pages: list[StorybookPageResponse] cover_prompt: str cover_url: str | None = None @router.post("/storybook/generate", response_model=StorybookResponse) async def generate_storybook_api( request: StorybookRequest, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): """生成分页故事书并保存。 返回故事书结构,包含每页文字和图像提示词。 """ _check_rate_limit(user.id) # 验证档案和宇宙 # 复用 _validate_profile_and_universe 需要将 request 转换为 GenerateRequest 或稍微修改验证函数 # 这里我们直接手动验证,或重构验证函数。为了简单,手动调用部分逻辑。 # 构建临时的 GenerateRequest 用于验证验证函数签名(或者直接手动查库更好) profile_id = request.child_profile_id universe_id = request.universe_id if profile_id: result = await db.execute( select(ChildProfile).where( ChildProfile.id == profile_id, ChildProfile.user_id == user.id, ) ) if not result.scalar_one_or_none(): raise HTTPException(status_code=404, detail="孩子档案不存在") if universe_id: result = await db.execute( select(StoryUniverse) .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) .where( StoryUniverse.id == universe_id, ChildProfile.user_id == user.id, ) ) universe = result.scalar_one_or_none() if not universe: raise HTTPException(status_code=404, detail="故事宇宙不存在") if profile_id and universe.child_profile_id != profile_id: raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") if not profile_id: profile_id = universe.child_profile_id logger.info( "storybook_request", user_id=user.id, keywords=request.keywords, page_count=request.page_count, profile_id=profile_id, universe_id=universe_id, ) memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) try: # 注意:generate_storybook 目前可能不支持记忆上下文注入 # 我们需要看看 generate_storybook 的签名 # 如果不支持,记忆功能在绘本模式下暂不可用,但基本参数传递是支持的 storybook = await generate_storybook( keywords=request.keywords, page_count=request.page_count, education_theme=request.education_theme, memory_context=memory_context, db=db, ) except Exception as e: logger.error("storybook_generation_failed", error=str(e)) raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") # ============================================================================== # 核心升级: 并行全量生成 (Parallel Full Rendering) # ============================================================================== final_cover_url = storybook.cover_url if request.generate_images: logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages)) # 1. 准备所有生图任务 (封面 + 所有内页) tasks = [] # 封面任务 async def _gen_cover(): if storybook.cover_prompt and not storybook.cover_url: try: return await generate_image(storybook.cover_prompt, db=db) except Exception as e: logger.warning("cover_gen_failed", error=str(e)) return storybook.cover_url tasks.append(_gen_cover()) # 内页任务 async def _gen_page(page): if page.image_prompt and not page.image_url: try: url = await generate_image(page.image_prompt, db=db) page.image_url = url except Exception as e: logger.warning("page_gen_failed", page=page.page_number, error=str(e)) for page in storybook.pages: tasks.append(_gen_page(page)) # 2. 并发执行所有任务 # 使用 return_exceptions=True 防止单张失败影响整体 results = await asyncio.gather(*tasks, return_exceptions=True) # 3. 更新封面结果 (results[0] 是封面任务的返回值) cover_res = results[0] if isinstance(cover_res, str): final_cover_url = cover_res logger.info("storybook_parallel_generation_complete") # ============================================================================== # 构建并保存 Story 对象 # 将 pages 对象转换为字典列表以存入 JSON 字段 pages_data = [ { "page_number": p.page_number, "text": p.text, "image_prompt": p.image_prompt, "image_url": p.image_url, } for p in storybook.pages ] story = Story( user_id=user.id, child_profile_id=profile_id, universe_id=universe_id, title=storybook.title, mode="storybook", pages=pages_data, # 存入 JSON 字段 story_text=None, # 绘本模式下,主文本可为空,或者可以存个摘要 cover_prompt=storybook.cover_prompt, image_url=final_cover_url, ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) # 构建响应 (使用更新后的 pages_data) response_pages = [ StorybookPageResponse( page_number=p["page_number"], text=p["text"], image_prompt=p["image_prompt"], image_url=p.get("image_url"), ) for p in pages_data ] return StorybookResponse( id=story.id, title=storybook.title, main_character=storybook.main_character, art_style=storybook.art_style, pages=response_pages, cover_prompt=storybook.cover_prompt, cover_url=final_cover_url, ) class AchievementItem(BaseModel): type: str description: str obtained_at: str | None = None @router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem]) async def get_story_achievements( story_id: int, user: User = Depends(require_user), db: AsyncSession = Depends(get_db), ): """Get achievements unlocked by a specific story.""" # 使用 joinedload 避免 N+1 查询 result = await db.execute( select(Story) .options(joinedload(Story.story_universe)) .where(Story.id == story_id, Story.user_id == user.id) ) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="Story not found") if not story.universe_id or not story.story_universe: return [] universe = story.story_universe if not universe.achievements: return [] results = [] for ach in universe.achievements: if isinstance(ach, dict) and ach.get("source_story_id") == story_id: results.append(AchievementItem( type=ach.get("type", "Unknown"), description=ach.get("description", ""), obtained_at=ach.get("obtained_at") )) return results