"""Story business logic service."""

import asyncio
import json
import uuid
from typing import Literal

from fastapi import HTTPException
from sqlalchemy import select, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
    GenerateRequest,
    StorybookRequest,
    FullStoryResponse,
    StorybookResponse,
    StorybookPageResponse,
    AchievementItem,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
    generate_story_content,
    generate_image,
    generate_storybook,
)
from app.tasks.achievements import extract_story_achievements

logger = get_logger(__name__)


async def validate_profile_and_universe(
    profile_id: str | None,
    universe_id: str | None,
    user_id: str,
    db: AsyncSession,
) -> tuple[str | None, str | None]:
    """Validate that the child profile / story universe belong to the user.

    Args:
        profile_id: Optional child profile id supplied by the client.
        universe_id: Optional story universe id supplied by the client.
        user_id: Id of the requesting user; both resources must be owned by them.
        db: Async database session.

    Returns:
        ``(profile_id, universe_id)``; ``(None, None)`` when neither id is given.
        When only ``universe_id`` is given, ``profile_id`` is filled in from the
        universe's ``child_profile_id``.

    Raises:
        HTTPException: 404 if the profile or universe is not found for this
            user; 400 if both are given but the universe belongs to a
            different profile.
    """
    if not profile_id and not universe_id:
        return None, None

    if profile_id:
        result = await db.execute(
            select(ChildProfile).where(
                ChildProfile.id == profile_id,
                ChildProfile.user_id == user_id,
            )
        )
        profile = result.scalar_one_or_none()
        if not profile:
            raise HTTPException(status_code=404, detail="孩子档案不存在")

    if universe_id:
        # Ownership of the universe is enforced by joining to its child
        # profile and filtering on that profile's user_id.
        result = await db.execute(
            select(StoryUniverse)
            .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
            .where(
                StoryUniverse.id == universe_id,
                ChildProfile.user_id == user_id,
            )
        )
        universe = result.scalar_one_or_none()
        if not universe:
            raise HTTPException(status_code=404, detail="故事宇宙不存在")
        if profile_id and universe.child_profile_id != profile_id:
            raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
        if not profile_id:
            profile_id = universe.child_profile_id

    return profile_id, universe_id


async def generate_and_save_story(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Generate story text for ``request`` and persist it.

    Validates profile/universe ownership, builds the memory context, calls the
    text provider, then saves via ``create_story_from_result`` (which also
    queues achievement extraction when a universe is set).

    Raises:
        HTTPException: 404/400 from validation; 502 if the provider call fails.
    """
    # 1. Validate
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )

    # 2. Build context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # 3. Generate
    try:
        result = await generate_story_content(
            input_type=request.type,
            data=request.data,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as exc:
        logger.error("story_generation_failed", error=str(exc))
        # Generic client message; provider details stay in the server log.
        raise HTTPException(
            status_code=502, detail="Story generation failed, please try again."
        ) from exc

    # 4. Save + 5. trigger async tasks — shared with the stream endpoint so
    # both paths persist stories the same way.
    return await create_story_from_result(result, user_id, profile_id, universe_id, db)


async def generate_full_story_service(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> FullStoryResponse:
    """Generate and save a story, then try to generate its cover image.

    Cover generation is best-effort: a provider failure is logged and reported
    in ``errors["image"]`` of the response instead of failing the request.
    """
    # 1. Generate + persist the text part (handles validation and saving).
    story = await generate_and_save_story(request, user_id, db)

    # 2. Generate the cover image (runs after the text, not in parallel).
    image_url: str | None = None
    errors: dict[str, str | None] = {}
    if story.cover_prompt:
        try:
            image_url = await generate_image(story.cover_prompt, db=db)
        except Exception as exc:
            errors["image"] = str(exc)
            logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
        else:
            # Only the provider call is best-effort; a DB failure is a real
            # error and must not be reported as an image failure.
            story.image_url = image_url
            await db.commit()
            # Reload after commit: with expire_on_commit=True (SQLAlchemy's
            # default) the attribute reads below would otherwise need an
            # implicit lazy load, which AsyncSession cannot perform.
            await db.refresh(story)

    return FullStoryResponse(
        id=story.id,
        title=story.title,
        story_text=story.story_text,
        cover_prompt=story.cover_prompt,
        image_url=image_url,
        audio_ready=False,
        mode=story.mode,
        errors=errors,
        child_profile_id=story.child_profile_id,
        universe_id=story.universe_id,
    )


async def _generate_storybook_images(storybook, db: AsyncSession) -> str | None:
    """Generate missing cover/page images concurrently; return the cover URL.

    Page images are written onto ``page.image_url`` in place. Each failure is
    logged and swallowed so one bad image never fails the whole storybook.
    Returns the newly generated cover URL, or ``storybook.cover_url`` when no
    new cover was produced.
    """
    logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))

    async def _gen_cover():
        # Only generate when there is a prompt and no existing cover.
        if storybook.cover_prompt and not storybook.cover_url:
            try:
                return await generate_image(storybook.cover_prompt, db=db)
            except Exception as e:
                logger.warning("cover_gen_failed", error=str(e))
        return storybook.cover_url

    async def _gen_page(page):
        if page.image_prompt and not page.image_url:
            try:
                page.image_url = await generate_image(page.image_prompt, db=db)
            except Exception as e:
                logger.warning("page_gen_failed", page=page.page_number, error=str(e))

    # NOTE(review): every task shares the same AsyncSession `db`. SQLAlchemy
    # forbids using one AsyncSession from concurrent tasks; if generate_image
    # issues queries through `db`, this gather can fail intermittently —
    # confirm what generate_image does with `db`.
    results = await asyncio.gather(
        _gen_cover(),
        *(_gen_page(page) for page in storybook.pages),
        return_exceptions=True,
    )

    # results[0] belongs to the cover task; anything but a str means no new URL.
    cover_res = results[0]
    final_cover_url = cover_res if isinstance(cover_res, str) else storybook.cover_url

    logger.info("storybook_parallel_generation_complete")
    return final_cover_url


async def generate_storybook_service(
    request: StorybookRequest,
    user_id: str,
    db: AsyncSession,
) -> StorybookResponse:
    """Generate a multi-page storybook, optionally with images, and save it.

    Raises:
        HTTPException: 404/400 from validation; 500 if text generation fails.
    """
    # 1. Validate
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )

    logger.info(
        "storybook_request",
        user_id=user_id,
        keywords=request.keywords,
        page_count=request.page_count,
        profile_id=profile_id,
        universe_id=universe_id,
    )

    # 2. Context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # 3. Generate text structure
    try:
        storybook = await generate_storybook(
            keywords=request.keywords,
            page_count=request.page_count,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as e:
        logger.error("storybook_generation_failed", error=str(e))
        # Do not echo provider exception text to the client; it is logged above.
        raise HTTPException(status_code=500, detail="故事书生成失败") from e

    # 4. Concurrent image generation (best-effort)
    final_cover_url = storybook.cover_url
    if request.generate_images:
        final_cover_url = await _generate_storybook_images(storybook, db)

    # 5. Save to DB
    pages_data = [
        {
            "page_number": p.page_number,
            "text": p.text,
            "image_prompt": p.image_prompt,
            "image_url": p.image_url,
        }
        for p in storybook.pages
    ]

    story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=storybook.title,
        mode="storybook",
        pages=pages_data,
        story_text=None,
        cover_prompt=storybook.cover_prompt,
        image_url=final_cover_url,
    )
    db.add(story)
    await db.commit()
    await db.refresh(story)

    if universe_id:
        extract_story_achievements.delay(story.id, universe_id)

    # 6. Build response
    response_pages = [
        StorybookPageResponse(
            page_number=p["page_number"],
            text=p["text"],
            image_prompt=p["image_prompt"],
            image_url=p.get("image_url"),
        )
        for p in pages_data
    ]

    return StorybookResponse(
        id=story.id,
        title=storybook.title,
        main_character=storybook.main_character,
        art_style=storybook.art_style,
        pages=response_pages,
        cover_prompt=storybook.cover_prompt,
        cover_url=final_cover_url,
    )


# ==================== Story CRUD & auxiliary operations (Issue #5) ====================

async def list_stories(
    user_id: str,
    limit: int,
    offset: int,
    db: AsyncSession,
) -> list[Story]:
    """Return the user's stories, newest first, paginated by offset/limit."""
    result = await db.execute(
        select(Story)
        .where(Story.user_id == user_id)
        .order_by(desc(Story.created_at))
        .offset(offset)
        .limit(limit)
    )
    # .all() is typed as a Sequence; materialize to match the annotation.
    return list(result.scalars().all())


async def get_story_detail(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Return one story owned by the user.

    Raises:
        HTTPException: 404 if it does not exist or belongs to another user.
    """
    result = await db.execute(
        select(Story).where(Story.id == story_id, Story.user_id == user_id)
    )
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")
    return story


async def delete_story(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> None:
    """Delete a story owned by the user (404 via ``get_story_detail`` if absent)."""
    story = await get_story_detail(story_id, user_id, db)
    await db.delete(story)
    await db.commit()


async def create_story_from_result(
    result,  # StoryOutput
    user_id: str,
    profile_id: str | None,
    universe_id: str | None,
    db: AsyncSession,
) -> Story:
    """Persist a generated story and queue achievement extraction.

    Used by ``generate_and_save_story`` and by the stream endpoint. When
    ``universe_id`` is set, ``extract_story_achievements`` is queued via
    ``.delay`` after the story is committed.
    """
    story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=result.title,
        story_text=result.story_text,
        cover_prompt=result.cover_prompt_suggestion,
        mode=result.mode,
    )
    db.add(story)
    await db.commit()
    await db.refresh(story)

    if universe_id:
        extract_story_achievements.delay(story.id, universe_id)

    return story


async def generate_story_cover(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> str:
    """Generate a cover image for an existing story and save its URL.

    Raises:
        HTTPException: 404 if the story is not found; 400 if it has no cover
            prompt; 500 if image generation fails.
    """
    story = await get_story_detail(story_id, user_id, db)
    if not story.cover_prompt:
        raise HTTPException(status_code=400, detail="Story has no cover prompt")

    # Only the provider call is wrapped, so a DB commit failure is not
    # misreported as an image-generation failure.
    try:
        image_url = await generate_image(story.cover_prompt, db=db)
    except Exception as e:
        logger.error("cover_generation_failed", story_id=story_id, error=str(e))
        raise HTTPException(status_code=500, detail="Image generation failed") from e

    story.image_url = image_url
    await db.commit()
    return image_url


async def generate_story_audio(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> bytes:
    """Synthesize speech for a story's text on the fly.

    Raises:
        HTTPException: 404 if the story is not found; 400 if it has no
            ``story_text`` (e.g. storybook-mode stories are saved with
            ``story_text=None``); 500 if synthesis fails.
    """
    story = await get_story_detail(story_id, user_id, db)
    if not story.story_text:
        raise HTTPException(status_code=400, detail="Story has no text")

    # TODO: Check if audio is already cached/saved?
    # For now, generate on the fly via provider.
    # NOTE(review): imported locally; could likely join the module-level
    # provider_router import — confirm there is no import-order reason.
    from app.services.provider_router import text_to_speech

    try:
        return await text_to_speech(story.story_text, db=db)
    except Exception as e:
        logger.error("audio_generation_failed", story_id=story_id, error=str(e))
        raise HTTPException(status_code=500, detail="Audio generation failed") from e


async def get_story_achievements(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> list[AchievementItem]:
    """Return achievements in the story's universe whose source is this story.

    Raises:
        HTTPException: 404 if the story is not found for this user.
    """
    result = await db.execute(
        select(Story)
        .options(joinedload(Story.story_universe))
        .where(Story.id == story_id, Story.user_id == user_id)
    )
    story = result.scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")

    if not story.universe_id or not story.story_universe:
        return []

    universe = story.story_universe
    if not universe.achievements:
        return []

    # Assumes universe.achievements is a list of dicts tagged with
    # "source_story_id" (presumably written by extract_story_achievements —
    # confirm); non-dict entries are skipped.
    return [
        AchievementItem(
            type=ach.get("type", "Unknown"),
            description=ach.get("description", ""),
            obtained_at=ach.get("obtained_at"),
        )
        for ach in universe.achievements
        if isinstance(ach, dict) and ach.get("source_story_id") == story_id
    ]