"""Story related APIs."""

import json
import uuid
from typing import Any, AsyncGenerator

from fastapi import APIRouter, Depends, Query, Response
from sse_starlette.sse import EventSourceResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
    AchievementItem,
    FullStoryResponse,
    GenerateRequest,
    StoryDetailResponse,
    StoryImageResponse,
    StoryListItem,
    StoryResponse,
    StorybookRequest,
    StorybookResponse,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
    generate_story_content,
    generate_image,
)
from app.services.story_status import StoryAssetStatus, sync_story_status

logger = get_logger(__name__)
router = APIRouter()

# Arguments passed to check_rate_limit by every generation endpoint below.
# All of them use the same key f"story:{user.id}", so they share one
# per-user budget (presumably RATE_LIMIT_REQUESTS calls per
# RATE_LIMIT_WINDOW seconds — confirm against check_rate_limit).
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_REQUESTS = 10


def _status_payload(story: Any) -> dict[str, Any]:
    """Return the status fields reported alongside SSE events for *story*.

    The keys are always emitted in this order: generation_status,
    image_status, audio_status, last_error.
    """
    return {
        "generation_status": story.generation_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "last_error": story.last_error,
    }


@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate or enhance a story.

    Applies the shared per-user rate limit, then delegates generation and
    persistence to ``story_service.generate_and_save_story``.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    return await story_service.generate_and_save_story(request, user.id, db)


@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a complete story (story + parallel image/audio generation).

    Applies the shared per-user rate limit, then delegates to
    ``story_service.generate_full_story_service``.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    return await story_service.generate_full_story_service(request, user.id, db)


@router.post("/stories/generate/stream")
async def generate_story_stream(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a story and stream progress as Server-Sent Events (SSE).

    Rate limiting, profile/universe validation and memory-context building
    run before the stream opens. Any exception they raise therefore
    propagates from the endpoint itself rather than as an SSE event.

    Once streaming, events are emitted in this order:

    - ``started``: ``{"story_id": <random uuid string>}``.
    - ``story_failed``: ``{"error": ...}`` when content generation or saving
      the story fails; the stream then ends.
    - ``story_ready``: the saved story's fields plus its status fields.
    - ``image_ready`` / ``image_failed``: only when the story has a
      ``cover_prompt``.
    - ``complete``: the story id plus its final status fields.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)

    # Validation
    profile_id, universe_id = await story_service.validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user.id, db
    )

    # Build context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    async def event_generator() -> AsyncGenerator[dict, None]:
        # NOTE(review): this generator runs after the endpoint has returned,
        # yet keeps using the request-scoped ``db`` session. Confirm that the
        # installed FastAPI version keeps yield-dependencies open until a
        # streaming response finishes.

        # This uuid only correlates the stream. It is NOT the saved story's
        # primary key, which ``story_ready``/``complete`` report as story.id.
        story_id = str(uuid.uuid4())
        yield {"event": "started", "data": json.dumps({"story_id": story_id})}

        # Step 1: generate content.
        try:
            result = await generate_story_content(
                input_type=request.type,
                data=request.data,
                education_theme=request.education_theme,
                memory_context=memory_context,
                db=db,
            )
        except Exception as e:
            logger.error("sse_story_generation_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        # Persist the story. Without this guard, a save error escapes the
        # generator and the client's stream ends with no terminal event.
        try:
            story = await story_service.create_story_from_result(
                result, user.id, profile_id, universe_id, db
            )
        except Exception as e:
            logger.error("sse_story_save_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        yield {
            "event": "story_ready",
            "data": json.dumps({
                "id": story.id,
                "title": story.title,
                "content": story.story_text,
                "cover_prompt": story.cover_prompt,
                "mode": story.mode,
                "child_profile_id": story.child_profile_id,
                "universe_id": story.universe_id,
                **_status_payload(story),
            }),
        }

        # Step 2: generate the cover image (only when a prompt was produced).
        # NOTE(review): if the client disconnects during this step, the
        # generator is presumably cancelled. image_status may then stay
        # GENERATING — confirm whether a cleanup path is needed.
        if story.cover_prompt:
            sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
            await db.commit()
            try:
                # Direct call to the provider router, sharing this db session.
                image_url = await generate_image(story.cover_prompt, db=db)
                story.image_url = image_url
                sync_story_status(story, image_status=StoryAssetStatus.READY)
                await db.commit()
            except Exception as e:
                sync_story_status(
                    story,
                    image_status=StoryAssetStatus.FAILED,
                    last_error=str(e),
                )
                await db.commit()
                logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
                yield {
                    "event": "image_failed",
                    "data": json.dumps({"error": str(e), **_status_payload(story)}),
                }
            else:
                yield {
                    "event": "image_ready",
                    "data": json.dumps({"image_url": image_url, **_status_payload(story)}),
                }

        yield {
            "event": "complete",
            "data": json.dumps({"story_id": story.id, **_status_payload(story)}),
        }

    return EventSourceResponse(event_generator())


@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
    request: StorybookRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a storybook.

    Shares the per-user rate-limit key with the story generation endpoints,
    then delegates to ``story_service.generate_storybook_service``.
    """
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
    return await story_service.generate_storybook_service(request, user.id, db)


# ==================== Missing Endpoints (Issue #5) ====================


@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
    # Negative values are rejected with a 422 instead of reaching the query.
    # NOTE(review): ``limit`` has no upper bound; consider capping page size.
    limit: int = Query(20, ge=0),
    offset: int = Query(0, ge=0),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's stories, paginated by ``limit``/``offset``."""
    return await story_service.list_stories(user.id, limit, offset, db)


@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
async def get_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the detail view of one story, looked up for the current user."""
    return await story_service.get_story_detail(story_id, user.id, db)


@router.delete("/stories/{story_id}")
async def delete_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a story via ``story_service.delete_story`` and confirm."""
    await story_service.delete_story(story_id, user.id, db)
    return {"message": "Deleted"}
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a cover image for one of the caller's stories.

    Returns the URL from ``story_service.generate_story_cover`` together with
    the status fields of the story, fetched after generation.
    """
    # NOTE(review): unlike the /stories/generate* endpoints, this route
    # applies no check_rate_limit call here. Confirm whether the service
    # layer limits provider calls.
    cover_url = await story_service.generate_story_cover(story_id, user.id, db)
    # Fetched after generation so the status fields are read post-update
    # (assumes generate_story_cover persists them — confirm).
    refreshed = await story_service.get_story_detail(story_id, user.id, db)
    return dict(
        image_url=cover_url,
        generation_status=refreshed.generation_status,
        image_status=refreshed.image_status,
        audio_status=refreshed.audio_status,
        last_error=refreshed.last_error,
    )


@router.get("/audio/{story_id}")
async def get_story_audio(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Serve the bytes from ``story_service.generate_story_audio`` as MP3."""
    mp3_bytes = await story_service.generate_story_audio(story_id, user.id, db)
    return Response(media_type="audio/mpeg", content=mp3_bytes)


@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List the achievements recorded for a story of the current user."""
    achievements = await story_service.get_story_achievements(story_id, user.id, db)
    return achievements