"""Story related APIs.

Exposes the unified ``/generations`` workflow alongside the older ``/stories``,
``/storybook``, ``/image`` and ``/audio`` endpoints. Deprecated generation
endpoints advertise their successor through ``Deprecation``/``Link`` headers.
"""

import json
import uuid
from typing import Any, AsyncGenerator

from fastapi import APIRouter, Depends, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from app.core.deps import require_user
from app.core.logging import get_logger
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
from app.db.models import User
from app.schemas.story_schemas import (
    AchievementItem,
    FullStoryResponse,
    GenerateRequest,
    GenerationRequest,
    GenerationResponse,
    StoryAssetRetryRequest,
    StorybookRequest,
    StorybookResponse,
    StoryDetailResponse,
    StoryImageResponse,
    StoryListItem,
    StoryResponse,
)
from app.services import story_service
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
    generate_image,
    generate_story_content,
)
from app.services.story_status import StoryAssetStatus, sync_story_status

logger = get_logger(__name__)
router = APIRouter()

RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_REQUESTS = 10

# Successor endpoint advertised by the deprecated story/storybook generation routes.
_GENERATIONS_SUCCESSOR = "/api/generations"


def _legacy_generation_headers(successor: str) -> dict[str, str]:
    """Build the deprecation headers pointing clients at ``successor``."""
    return {
        "Deprecation": "true",
        "Link": f"<{successor}>; rel=\"successor-version\"",
        "X-DreamWeaver-Successor-Endpoint": successor,
    }


def _mark_legacy_generation_endpoint(response: Response, successor: str) -> None:
    """Attach the deprecation headers for ``successor`` to ``response``."""
    response.headers.update(_legacy_generation_headers(successor))


def _retry_assets_successor(story_id: int) -> str:
    """Return the unified-API asset retry URL that replaces a legacy asset endpoint."""
    return f"{_GENERATIONS_SUCCESSOR}/{story_id}/retry-assets"


async def _enforce_story_rate_limit(user: User) -> None:
    """Apply the shared per-user rate limit used by all generation endpoints."""
    await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)


def _status_fields(story: Any) -> dict[str, Any]:
    """Collect the generation/asset status fields reported to clients.

    Only attribute access is used, so any object exposing these four attributes
    works (the story returned by the service layer, or a detail response).
    """
    return {
        "generation_status": story.generation_status,
        "image_status": story.image_status,
        "audio_status": story.audio_status,
        "last_error": story.last_error,
    }


@router.post("/generations", response_model=GenerationResponse)
async def create_generation(
    request: GenerationRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a story or storybook through the unified generation workflow."""
    await _enforce_story_rate_limit(user)
    return await story_service.generate_generation_service(request, user.id, db)


@router.get("/generations/{story_id}", response_model=StoryDetailResponse)
async def get_generation(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a generated story/storybook through the unified generation API."""
    return await story_service.get_story_detail(story_id, user.id, db)


@router.post("/generations/{story_id}/retry-assets", response_model=StoryDetailResponse)
async def retry_generation_assets(
    story_id: int,
    payload: StoryAssetRetryRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Retry generated assets through the unified generation API."""
    return await story_service.retry_story_assets(story_id, user.id, payload.assets, db)


@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
    request: GenerateRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate or enhance a story (deprecated; see ``/api/generations``)."""
    _mark_legacy_generation_endpoint(response, _GENERATIONS_SUCCESSOR)
    await _enforce_story_rate_limit(user)
    return await story_service.generate_and_save_story(request, user.id, db)


@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
    request: GenerateRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate complete story (story + parallel image/audio generation).

    Deprecated; see ``/api/generations``.
    """
    _mark_legacy_generation_endpoint(response, _GENERATIONS_SUCCESSOR)
    await _enforce_story_rate_limit(user)
    return await story_service.generate_full_story_service(request, user.id, db)


@router.post("/stories/generate/stream")
async def generate_story_stream(
    request: GenerateRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Stream story generation as Server-Sent Events (deprecated).

    Event sequence:
      * ``started`` – a random per-stream id (not the database story id).
      * ``story_failed`` – generation or saving failed; the stream ends.
      * ``story_ready`` – the saved story and its status fields.
      * ``image_ready`` / ``image_failed`` – only when the story has a cover prompt.
      * ``complete`` – final status fields.

    Validation and rate limiting happen before the stream starts, so those
    errors are returned as ordinary HTTP errors.
    """
    await _enforce_story_rate_limit(user)

    # Validation
    profile_id, universe_id = await story_service.validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user.id, db
    )

    # Build Context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # NOTE(review): the request-scoped ``db`` session is used inside the
    # streaming generator below; whether the get_db dependency has already run
    # its cleanup by then depends on the FastAPI version — confirm.
    async def event_generator() -> AsyncGenerator[dict, None]:
        stream_id = str(uuid.uuid4())
        yield {"event": "started", "data": json.dumps({"story_id": stream_id})}

        # Step 1: Generate Content
        try:
            result = await generate_story_content(
                input_type=request.type,
                data=request.data,
                education_theme=request.education_theme,
                memory_context=memory_context,
                db=db,
            )
        except Exception as e:
            logger.error("sse_story_generation_failed", error=str(e))
            yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
            return

        # Step 2: Save Story. Previously unguarded, so a save error cut the stream
        # off without any event; report it instead, without leaking DB details.
        try:
            story = await story_service.create_story_from_result(
                result, user.id, profile_id, universe_id, db
            )
        except Exception as e:
            logger.error("sse_story_save_failed", error=str(e))
            yield {
                "event": "story_failed",
                "data": json.dumps({"error": "Failed to save story"}),
            }
            return

        yield {
            "event": "story_ready",
            "data": json.dumps(
                {
                    "id": story.id,
                    "title": story.title,
                    "content": story.story_text,
                    "cover_prompt": story.cover_prompt,
                    "mode": story.mode,
                    "child_profile_id": story.child_profile_id,
                    "universe_id": story.universe_id,
                    **_status_fields(story),
                }
            ),
        }

        # Step 3: Generate Image (only when the model produced a cover prompt)
        if story.cover_prompt:
            sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
            await db.commit()
            try:
                # Direct call to provider router's generate_image, sharing db session
                image_url = await generate_image(story.cover_prompt, db=db)
                story.image_url = image_url
                sync_story_status(story, image_status=StoryAssetStatus.READY)
                await db.commit()
                yield {
                    "event": "image_ready",
                    "data": json.dumps({"image_url": image_url, **_status_fields(story)}),
                }
            except Exception as e:
                # Image failure is non-fatal: record it on the story and keep streaming.
                sync_story_status(
                    story,
                    image_status=StoryAssetStatus.FAILED,
                    last_error=str(e),
                )
                await db.commit()
                logger.warning("sse_image_generation_failed", story_id=story.id, error=str(e))
                yield {
                    "event": "image_failed",
                    "data": json.dumps({"error": str(e), **_status_fields(story)}),
                }

        yield {
            "event": "complete",
            "data": json.dumps({"story_id": story.id, **_status_fields(story)}),
        }

    return EventSourceResponse(
        event_generator(),
        headers=_legacy_generation_headers(_GENERATIONS_SUCCESSOR),
    )


@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
    request: StorybookRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate storybook (deprecated; see ``/api/generations``)."""
    _mark_legacy_generation_endpoint(response, _GENERATIONS_SUCCESSOR)
    await _enforce_story_rate_limit(user)
    return await story_service.generate_storybook_service(request, user.id, db)


@router.get("/stories", response_model=list[StoryListItem])
async def list_stories(
    limit: int = Query(20, ge=0),
    offset: int = Query(0, ge=0),
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """List stories for the authenticated user, paginated by ``limit``/``offset``.

    Negative values are rejected with 422. Otherwise they would reach the
    query: some backends error on them, and SQLite treats a negative LIMIT as
    "no limit".
    """
    return await story_service.list_stories(user.id, limit, offset, db)


@router.get("/stories/{story_id}", response_model=StoryDetailResponse)
async def get_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get story detail."""
    return await story_service.get_story_detail(story_id, user.id, db)


@router.delete("/stories/{story_id}")
async def delete_story(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete story."""
    await story_service.delete_story(story_id, user.id, db)
    return {"message": "Deleted"}


@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
    story_id: int,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate cover image for story (deprecated; see the unified retry-assets API)."""
    _mark_legacy_generation_endpoint(response, _retry_assets_successor(story_id))
    url = await story_service.generate_story_cover(story_id, user.id, db)
    # Re-read the detail so the response reflects the post-generation statuses.
    story = await story_service.get_story_detail(story_id, user.id, db)
    return {
        "image_url": url,
        **_status_fields(story),
        "retryable_assets": story.retryable_assets,
    }


@router.post("/stories/{story_id}/assets/retry", response_model=StoryDetailResponse)
async def retry_story_assets(
    story_id: int,
    payload: StoryAssetRetryRequest,
    response: Response,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Retry selected generated assets for a story (deprecated)."""
    _mark_legacy_generation_endpoint(response, _retry_assets_successor(story_id))
    return await story_service.retry_story_assets(story_id, user.id, payload.assets, db)


@router.get("/audio/{story_id}")
async def get_story_audio(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get story audio (MP3)."""
    audio_bytes = await story_service.generate_story_audio(story_id, user.id, db)
    return Response(content=audio_bytes, media_type="audio/mpeg")


@router.get("/stories/{story_id}/achievements", response_model=list[AchievementItem])
async def get_story_achievements(
    story_id: int,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Get story achievements."""
    return await story_service.get_story_achievements(story_id, user.id, db)