"""Story business logic service.""" import asyncio from fastapi import HTTPException from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.core.logging import get_logger from app.db.models import ChildProfile, Story, StoryUniverse from app.schemas.story_schemas import ( GenerateRequest, StorybookRequest, FullStoryResponse, StorybookResponse, StorybookPageResponse, AchievementItem, ) from app.services.audio_storage import ( audio_cache_exists, read_audio_cache, write_story_audio_cache, ) from app.services.memory_service import build_enhanced_memory_context from app.services.provider_router import ( generate_story_content, generate_image, generate_storybook, ) from app.services.story_status import ( StoryAssetStatus, sync_story_status, ) from app.tasks.achievements import extract_story_achievements logger = get_logger(__name__) def _build_storybook_error_message( *, cover_failed: bool, failed_pages: list[int], ) -> str | None: """Summarize storybook image generation errors for the latest attempt.""" parts: list[str] = [] if cover_failed: parts.append("封面生成失败") if failed_pages: pages = "、".join(str(page) for page in sorted(failed_pages)) parts.append(f"第 {pages} 页插图生成失败") return ";".join(parts) if parts else None def _resolve_storybook_image_status( *, generate_images: bool, cover_prompt: str | None, cover_url: str | None, pages_data: list[dict], ) -> StoryAssetStatus: """Resolve the persisted image status for a storybook.""" if not generate_images: return StoryAssetStatus.NOT_REQUESTED expected_assets = 0 ready_assets = 0 if cover_prompt or cover_url: expected_assets += 1 if cover_url: ready_assets += 1 for page in pages_data: if not page.get("image_prompt") and not page.get("image_url"): continue expected_assets += 1 if page.get("image_url"): ready_assets += 1 if expected_assets == 0: return StoryAssetStatus.NOT_REQUESTED if ready_assets == expected_assets: return StoryAssetStatus.READY return StoryAssetStatus.FAILED async def validate_profile_and_universe( profile_id: str | None, universe_id: str | None, user_id: str, db: AsyncSession, ) -> tuple[str | None, str | None]: """Validate child profile and universe ownership/relationship.""" if not profile_id and not universe_id: return None, None if profile_id: result = await db.execute( select(ChildProfile).where( ChildProfile.id == profile_id, ChildProfile.user_id == user_id, ) ) profile = result.scalar_one_or_none() if not profile: raise HTTPException(status_code=404, detail="孩子档案不存在") if universe_id: result = await db.execute( select(StoryUniverse) .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id) .where( StoryUniverse.id == universe_id, ChildProfile.user_id == user_id, ) ) universe = result.scalar_one_or_none() if not universe: raise HTTPException(status_code=404, detail="故事宇宙不存在") if profile_id and universe.child_profile_id != profile_id: raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配") if not profile_id: profile_id = universe.child_profile_id return profile_id, universe_id async def generate_and_save_story( request: GenerateRequest, user_id: str, db: AsyncSession, ) -> Story: """Generate generic story content and save to DB.""" # 1. Validate profile_id, universe_id = await validate_profile_and_universe( request.child_profile_id, request.universe_id, user_id, db ) # 2. Build Context memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) # 3. Generate try: result = await generate_story_content( input_type=request.type, data=request.data, education_theme=request.education_theme, memory_context=memory_context, db=db, ) except Exception as exc: raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc # 4. Save story = Story( user_id=user_id, child_profile_id=profile_id, universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) sync_story_status( story, image_status=StoryAssetStatus.NOT_REQUESTED, audio_status=StoryAssetStatus.NOT_REQUESTED, last_error=None, ) db.add(story) await db.commit() await db.refresh(story) # 5. Trigger Async Tasks if universe_id: extract_story_achievements.delay(story.id, universe_id) return story async def generate_full_story_service( request: GenerateRequest, user_id: str, db: AsyncSession, ) -> FullStoryResponse: """Generate story with parallel image generation.""" # 1. Generate text part # We can reuse logic or call generate_story_content directly if we want finer control # reusing generate_and_save_story to ensure consistency (it handles validation + saving) story = await generate_and_save_story(request, user_id, db) # 2. Generate Image (Parallel/Async step in this flow) image_url: str | None = None errors: dict[str, str | None] = {} if story.cover_prompt: sync_story_status(story, image_status=StoryAssetStatus.GENERATING) await db.commit() try: image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url sync_story_status( story, image_status=StoryAssetStatus.READY, ) await db.commit() except Exception as exc: errors["image"] = str(exc) sync_story_status( story, image_status=StoryAssetStatus.FAILED, last_error=str(exc), ) await db.commit() logger.warning("image_generation_failed", story_id=story.id, error=str(exc)) return FullStoryResponse( id=story.id, title=story.title, story_text=story.story_text, cover_prompt=story.cover_prompt, image_url=image_url, audio_ready=False, mode=story.mode, errors=errors, child_profile_id=story.child_profile_id, universe_id=story.universe_id, generation_status=story.generation_status, image_status=story.image_status, audio_status=story.audio_status, last_error=story.last_error, ) async def generate_storybook_service( request: StorybookRequest, user_id: str, db: AsyncSession, ) -> StorybookResponse: """Generate storybook with parallel image generation for pages.""" # 1. Validate profile_id, universe_id = await validate_profile_and_universe( request.child_profile_id, request.universe_id, user_id, db ) logger.info( "storybook_request", user_id=user_id, keywords=request.keywords, page_count=request.page_count, profile_id=profile_id, universe_id=universe_id, ) # 2. Context memory_context = await build_enhanced_memory_context(profile_id, universe_id, db) # 3. Generate Text Structure try: storybook = await generate_storybook( keywords=request.keywords, page_count=request.page_count, education_theme=request.education_theme, memory_context=memory_context, db=db, ) except Exception as e: logger.error("storybook_generation_failed", error=str(e)) raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") # 4. Parallel Image Generation final_cover_url = storybook.cover_url cover_failed = False failed_pages: list[int] = [] if request.generate_images: logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages)) tasks = [] async def _gen_cover(): nonlocal cover_failed if storybook.cover_prompt and not storybook.cover_url: try: return await generate_image(storybook.cover_prompt, db=db) except Exception as exc: cover_failed = True logger.warning("cover_gen_failed", error=str(exc)) return storybook.cover_url tasks.append(_gen_cover()) async def _gen_page(page): if page.image_prompt and not page.image_url: try: page.image_url = await generate_image(page.image_prompt, db=db) except Exception as exc: failed_pages.append(page.page_number) logger.warning("page_gen_failed", page=page.page_number, error=str(exc)) for page in storybook.pages: tasks.append(_gen_page(page)) results = await asyncio.gather(*tasks, return_exceptions=True) cover_res = results[0] if isinstance(cover_res, str): final_cover_url = cover_res logger.info("storybook_parallel_generation_complete") # 5. Save to DB pages_data = [ { "page_number": p.page_number, "text": p.text, "image_prompt": p.image_prompt, "image_url": p.image_url, } for p in storybook.pages ] story = Story( user_id=user_id, child_profile_id=profile_id, universe_id=universe_id, title=storybook.title, mode="storybook", pages=pages_data, story_text=None, cover_prompt=storybook.cover_prompt, image_url=final_cover_url, ) sync_story_status( story, image_status=_resolve_storybook_image_status( generate_images=request.generate_images, cover_prompt=storybook.cover_prompt, cover_url=final_cover_url, pages_data=pages_data, ), audio_status=StoryAssetStatus.NOT_REQUESTED, last_error=_build_storybook_error_message( cover_failed=cover_failed, failed_pages=failed_pages, ), ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) # 6. Build Response response_pages = [ StorybookPageResponse( page_number=p["page_number"], text=p["text"], image_prompt=p["image_prompt"], image_url=p.get("image_url"), ) for p in pages_data ] return StorybookResponse( id=story.id, title=storybook.title, main_character=storybook.main_character, art_style=storybook.art_style, pages=response_pages, cover_prompt=storybook.cover_prompt, cover_url=final_cover_url, generation_status=story.generation_status, image_status=story.image_status, audio_status=story.audio_status, last_error=story.last_error, ) # ==================== Missing Endpoints Logic (for Issue #5) ==================== async def list_stories( user_id: str, limit: int, offset: int, db: AsyncSession, ) -> list[Story]: """List stories for user.""" result = await db.execute( select(Story) .where(Story.user_id == user_id) .order_by(desc(Story.created_at)) .offset(offset) .limit(limit) ) return result.scalars().all() async def get_story_detail( story_id: int, user_id: str, db: AsyncSession, ) -> Story: """Get story detail.""" result = await db.execute( select(Story).where(Story.id == story_id, Story.user_id == user_id) ) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="Story not found") return story async def delete_story( story_id: int, user_id: str, db: AsyncSession, ) -> None: """Delete a story.""" story = await get_story_detail(story_id, user_id, db) await db.delete(story) await db.commit() async def create_story_from_result( result, # StoryOutput user_id: str, profile_id: str | None, universe_id: str | None, db: AsyncSession, ) -> Story: """Save a generated story to DB (helper for stream endpoint).""" story = Story( user_id=user_id, child_profile_id=profile_id, universe_id=universe_id, title=result.title, story_text=result.story_text, cover_prompt=result.cover_prompt_suggestion, mode=result.mode, ) sync_story_status( story, image_status=StoryAssetStatus.NOT_REQUESTED, audio_status=StoryAssetStatus.NOT_REQUESTED, last_error=None, ) db.add(story) await db.commit() await db.refresh(story) if universe_id: extract_story_achievements.delay(story.id, universe_id) return story async def generate_story_cover( story_id: int, user_id: str, db: AsyncSession, ) -> str: """Generate cover image for an existing story.""" story = await get_story_detail(story_id, user_id, db) if not story.cover_prompt: raise HTTPException(status_code=400, detail="Story has no cover prompt") sync_story_status(story, image_status=StoryAssetStatus.GENERATING) await db.commit() try: image_url = await generate_image(story.cover_prompt, db=db) story.image_url = image_url sync_story_status( story, image_status=StoryAssetStatus.READY, ) await db.commit() return image_url except Exception as e: sync_story_status( story, image_status=StoryAssetStatus.FAILED, last_error=str(e), ) await db.commit() logger.error("cover_generation_failed", story_id=story_id, error=str(e)) raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") async def generate_story_audio( story_id: int, user_id: str, db: AsyncSession, ) -> bytes: """Generate audio for a story.""" story = await get_story_detail(story_id, user_id, db) if not story.story_text: raise HTTPException(status_code=400, detail="Story has no text") if story.audio_path and audio_cache_exists(story.audio_path): if story.audio_status != StoryAssetStatus.READY.value: sync_story_status(story, audio_status=StoryAssetStatus.READY) await db.commit() return read_audio_cache(story.audio_path) if story.audio_path and not audio_cache_exists(story.audio_path): logger.warning( "story_audio_cache_missing", story_id=story_id, audio_path=story.audio_path, ) story.audio_path = None if story.audio_status == StoryAssetStatus.READY.value: sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED) await db.commit() from app.services.provider_router import text_to_speech sync_story_status(story, audio_status=StoryAssetStatus.GENERATING) await db.commit() try: audio_data = await text_to_speech(story.story_text, db=db) story.audio_path = write_story_audio_cache(story.id, audio_data) sync_story_status( story, audio_status=StoryAssetStatus.READY, ) await db.commit() return audio_data except Exception as e: story.audio_path = None sync_story_status( story, audio_status=StoryAssetStatus.FAILED, last_error=str(e), ) await db.commit() logger.error("audio_generation_failed", story_id=story_id, error=str(e)) raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") async def get_story_achievements( story_id: int, user_id: str, db: AsyncSession, ) -> list[AchievementItem]: """Get achievements unlocked by a specific story.""" result = await db.execute( select(Story) .options(joinedload(Story.story_universe)) .where(Story.id == story_id, Story.user_id == user_id) ) story = result.scalar_one_or_none() if not story: raise HTTPException(status_code=404, detail="Story not found") if not story.universe_id or not story.story_universe: return [] universe = story.story_universe if not universe.achievements: return [] results = [] for ach in universe.achievements: if isinstance(ach, dict) and ach.get("source_story_id") == story_id: results.append(AchievementItem( type=ach.get("type", "Unknown"), description=ach.get("description", ""), obtained_at=ach.get("obtained_at") )) return results