Files
dreamweaver/backend/app/services/story_service.py
torin a97a2fe005
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
feat: persist story generation states and cache audio
2026-04-17 17:14:09 +08:00

586 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Story business logic service."""
import asyncio
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
GenerateRequest,
StorybookRequest,
FullStoryResponse,
StorybookResponse,
StorybookPageResponse,
AchievementItem,
)
from app.services.audio_storage import (
audio_cache_exists,
read_audio_cache,
write_story_audio_cache,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_story_content,
generate_image,
generate_storybook,
)
from app.services.story_status import (
StoryAssetStatus,
sync_story_status,
)
from app.tasks.achievements import extract_story_achievements
# Module-level structured logger: calls in this module pass an event name plus
# keyword context fields (e.g. story_id=..., error=...).
logger = get_logger(__name__)
def _build_storybook_error_message(
*,
cover_failed: bool,
failed_pages: list[int],
) -> str | None:
"""Summarize storybook image generation errors for the latest attempt."""
parts: list[str] = []
if cover_failed:
parts.append("封面生成失败")
if failed_pages:
pages = "".join(str(page) for page in sorted(failed_pages))
parts.append(f"{pages} 页插图生成失败")
return "".join(parts) if parts else None
def _resolve_storybook_image_status(
    *,
    generate_images: bool,
    cover_prompt: str | None,
    cover_url: str | None,
    pages_data: list[dict],
) -> StoryAssetStatus:
    """Decide which image status to persist for a freshly generated storybook.

    The cover and each page are candidate assets. An asset is *expected* when it
    has a prompt or a URL, and *ready* when it has a URL.

    Returns:
        NOT_REQUESTED if images were not requested or no asset is expected;
        READY if every expected asset has a URL; FAILED otherwise.
    """
    if not generate_images:
        return StoryAssetStatus.NOT_REQUESTED

    # (prompt, url) pairs: the cover first, then every page.
    candidates = [(cover_prompt, cover_url)] + [
        (page.get("image_prompt"), page.get("image_url")) for page in pages_data
    ]
    expected_urls = [url for prompt, url in candidates if prompt or url]

    if not expected_urls:
        return StoryAssetStatus.NOT_REQUESTED
    if all(expected_urls):
        return StoryAssetStatus.READY
    return StoryAssetStatus.FAILED
async def validate_profile_and_universe(
    profile_id: str | None,
    universe_id: str | None,
    user_id: str,
    db: AsyncSession,
) -> tuple[str | None, str | None]:
    """Check that the referenced child profile and story universe belong to the user.

    Args:
        profile_id: Optional child profile id supplied by the client.
        universe_id: Optional story universe id supplied by the client.
        user_id: Id of the requesting user; both lookups are scoped to it.
        db: Async database session.

    Returns:
        ``(profile_id, universe_id)``. When only a universe is given, the
        profile id is taken from the universe's ``child_profile_id``.

    Raises:
        HTTPException: 404 if the profile or universe is not found for this
            user; 400 if the universe belongs to a different profile.
    """
    if not (profile_id or universe_id):
        return None, None

    if profile_id:
        owned_profile = (
            await db.execute(
                select(ChildProfile).where(
                    ChildProfile.id == profile_id,
                    ChildProfile.user_id == user_id,
                )
            )
        ).scalar_one_or_none()
        if not owned_profile:
            raise HTTPException(status_code=404, detail="孩子档案不存在")

    if not universe_id:
        return profile_id, universe_id

    # Ownership of a universe is established through its child profile.
    universe_query = (
        select(StoryUniverse)
        .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
        .where(
            StoryUniverse.id == universe_id,
            ChildProfile.user_id == user_id,
        )
    )
    universe = (await db.execute(universe_query)).scalar_one_or_none()
    if not universe:
        raise HTTPException(status_code=404, detail="故事宇宙不存在")
    if profile_id and universe.child_profile_id != profile_id:
        raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")

    # A universe implies its owning profile when none was passed explicitly.
    return profile_id or universe.child_profile_id, universe_id
async def generate_and_save_story(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Generate plain story content for the request and persist it.

    Steps: validate profile/universe ownership, build the memory context,
    call the story provider, then save the result with
    ``create_story_from_result`` (which also queues achievement extraction
    when a universe is attached).

    Raises:
        HTTPException: 404/400 from ``validate_profile_and_universe``; 502 with
            a generic message if the provider call fails.
    """
    # 1. Validate
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )
    # 2. Build context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
    # 3. Generate
    try:
        result = await generate_story_content(
            input_type=request.type,
            data=request.data,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as exc:
        raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
    # 4-5. Save and trigger async tasks. Shared with the stream endpoint's save
    # path so both always build and persist a Story the same way.
    return await create_story_from_result(result, user_id, profile_id, universe_id, db)
async def generate_full_story_service(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> FullStoryResponse:
    """Generate and save a story, then attempt to generate its cover image.

    Validation, text generation and persistence are delegated to
    ``generate_and_save_story``. A cover image is attempted only when the saved
    story has a ``cover_prompt``. An image failure does not fail the request:
    it is recorded on the story (``image_status``/``last_error``), logged, and
    echoed in the response under ``errors["image"]``.
    """
    story = await generate_and_save_story(request, user_id, db)

    errors: dict[str, str | None] = {}
    image_url: str | None = None

    if story.cover_prompt:
        # Persist the in-progress state before the slow provider call.
        sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
        await db.commit()
        try:
            image_url = await generate_image(story.cover_prompt, db=db)
            story.image_url = image_url
            sync_story_status(story, image_status=StoryAssetStatus.READY)
            await db.commit()
        except Exception as exc:
            message = str(exc)
            errors["image"] = message
            sync_story_status(
                story,
                image_status=StoryAssetStatus.FAILED,
                last_error=message,
            )
            await db.commit()
            logger.warning("image_generation_failed", story_id=story.id, error=message)

    return FullStoryResponse(
        id=story.id,
        title=story.title,
        story_text=story.story_text,
        cover_prompt=story.cover_prompt,
        image_url=image_url,
        # This flow never produces audio, so it is never ready here.
        audio_ready=False,
        mode=story.mode,
        errors=errors,
        child_profile_id=story.child_profile_id,
        universe_id=story.universe_id,
        generation_status=story.generation_status,
        image_status=story.image_status,
        audio_status=story.audio_status,
        last_error=story.last_error,
    )
async def generate_storybook_service(
    request: StorybookRequest,
    user_id: str,
    db: AsyncSession,
) -> StorybookResponse:
    """Generate a multi-page storybook, optionally illustrate it, and save it.

    Steps:
    1. Validate profile/universe ownership.
    2. Build the memory context.
    3. Generate the text structure.
    4. Generate the cover and page images concurrently, when
       ``request.generate_images`` is set.
    5. Persist the story with its resolved image status.
    6. Build the response.

    Image failures do not fail the request. They are logged, summarized in
    ``last_error`` and reflected in ``image_status``.

    Raises:
        HTTPException: 404/400 from ``validate_profile_and_universe``; 500 if
            storybook text generation fails.
    """
    # 1. Validate
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )
    logger.info(
        "storybook_request",
        user_id=user_id,
        keywords=request.keywords,
        page_count=request.page_count,
        profile_id=profile_id,
        universe_id=universe_id,
    )

    # 2. Context
    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    # 3. Generate text structure
    try:
        storybook = await generate_storybook(
            keywords=request.keywords,
            page_count=request.page_count,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as e:
        logger.error("storybook_generation_failed", error=str(e))
        # Chain the cause so the provider error is kept in the traceback.
        raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") from e

    # 4. Parallel image generation (best effort)
    final_cover_url = storybook.cover_url
    cover_failed = False
    failed_pages: list[int] = []
    if request.generate_images:
        logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))

        async def _gen_cover():
            """Return the cover URL, generating it when there is a prompt but no URL."""
            nonlocal cover_failed
            if storybook.cover_prompt and not storybook.cover_url:
                try:
                    return await generate_image(storybook.cover_prompt, db=db)
                except Exception as exc:
                    cover_failed = True
                    logger.warning("cover_gen_failed", error=str(exc))
            return storybook.cover_url

        async def _gen_page(page):
            """Fill ``page.image_url`` in place; record the page number on failure."""
            if page.image_prompt and not page.image_url:
                try:
                    page.image_url = await generate_image(page.image_prompt, db=db)
                except Exception as exc:
                    failed_pages.append(page.page_number)
                    logger.warning("page_gen_failed", page=page.page_number, error=str(exc))

        # NOTE(review): every task shares the same AsyncSession `db`. SQLAlchemy
        # does not support using one AsyncSession from concurrent tasks; if
        # generate_image touches the session, each task needs its own session.
        results = await asyncio.gather(
            _gen_cover(),
            *(_gen_page(page) for page in storybook.pages),
            return_exceptions=True,
        )
        # Both helpers handle provider errors themselves, so any exception
        # surfacing here is unexpected: log it instead of dropping it silently.
        for task_result in results:
            if isinstance(task_result, BaseException):
                logger.warning("storybook_image_task_failed", error=str(task_result))
        cover_res = results[0]
        if isinstance(cover_res, str):
            final_cover_url = cover_res
        logger.info("storybook_parallel_generation_complete")

    # 5. Save to DB
    pages_data = [
        {
            "page_number": p.page_number,
            "text": p.text,
            "image_prompt": p.image_prompt,
            "image_url": p.image_url,
        }
        for p in storybook.pages
    ]
    story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=storybook.title,
        mode="storybook",
        pages=pages_data,
        story_text=None,
        cover_prompt=storybook.cover_prompt,
        image_url=final_cover_url,
    )
    sync_story_status(
        story,
        image_status=_resolve_storybook_image_status(
            generate_images=request.generate_images,
            cover_prompt=storybook.cover_prompt,
            cover_url=final_cover_url,
            pages_data=pages_data,
        ),
        audio_status=StoryAssetStatus.NOT_REQUESTED,
        last_error=_build_storybook_error_message(
            cover_failed=cover_failed,
            failed_pages=failed_pages,
        ),
    )
    db.add(story)
    await db.commit()
    await db.refresh(story)
    if universe_id:
        extract_story_achievements.delay(story.id, universe_id)

    # 6. Build response
    response_pages = [
        StorybookPageResponse(
            page_number=p["page_number"],
            text=p["text"],
            image_prompt=p["image_prompt"],
            image_url=p.get("image_url"),
        )
        for p in pages_data
    ]
    return StorybookResponse(
        id=story.id,
        title=storybook.title,
        main_character=storybook.main_character,
        art_style=storybook.art_style,
        pages=response_pages,
        cover_prompt=storybook.cover_prompt,
        cover_url=final_cover_url,
        generation_status=story.generation_status,
        image_status=story.image_status,
        audio_status=story.audio_status,
        last_error=story.last_error,
    )
# ==================== Missing Endpoints Logic (for Issue #5) ====================
async def list_stories(
    user_id: str,
    limit: int,
    offset: int,
    db: AsyncSession,
) -> list[Story]:
    """Return one page of the user's stories, newest first.

    Args:
        user_id: Owner whose stories are listed.
        limit: Maximum number of stories to return.
        offset: Number of stories to skip (applied after ordering).
        db: Async database session.
    """
    page_query = (
        select(Story)
        .where(Story.user_id == user_id)
        .order_by(desc(Story.created_at))
        .offset(offset)
        .limit(limit)
    )
    rows = await db.execute(page_query)
    return rows.scalars().all()
async def get_story_detail(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Fetch a single story owned by the user.

    Raises:
        HTTPException: 404 when no story with this id belongs to the user.
    """
    owned_story_query = select(Story).where(Story.id == story_id, Story.user_id == user_id)
    story = (await db.execute(owned_story_query)).scalar_one_or_none()
    if story:
        return story
    raise HTTPException(status_code=404, detail="Story not found")
async def delete_story(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> None:
    """Delete one of the user's stories and commit the change.

    Raises:
        HTTPException: 404 (from ``get_story_detail``) if the story does not
            exist for this user.
    """
    doomed = await get_story_detail(story_id, user_id, db)
    await db.delete(doomed)
    await db.commit()
async def create_story_from_result(
    result,  # StoryOutput
    user_id: str,
    profile_id: str | None,
    universe_id: str | None,
    db: AsyncSession,
) -> Story:
    """Persist an already-generated text story (used by the stream endpoint).

    The story starts with no image or audio requested and no error recorded.
    When a universe is attached, achievement extraction is queued via
    ``extract_story_achievements.delay`` (presumably a background task queue).

    Returns:
        The saved story, refreshed from the database.
    """
    story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=result.title,
        story_text=result.story_text,
        cover_prompt=result.cover_prompt_suggestion,
        mode=result.mode,
    )
    sync_story_status(
        story,
        image_status=StoryAssetStatus.NOT_REQUESTED,
        audio_status=StoryAssetStatus.NOT_REQUESTED,
        last_error=None,
    )
    db.add(story)
    await db.commit()
    await db.refresh(story)

    if not universe_id:
        return story
    extract_story_achievements.delay(story.id, universe_id)
    return story
async def generate_story_cover(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> str:
    """Generate (or regenerate) the cover image for an existing story.

    ``image_status`` is set to GENERATING and committed before the provider
    call. Afterwards it is set to READY, with ``image_url`` stored, or to
    FAILED, with ``last_error`` stored.

    Returns:
        The URL of the generated image.

    Raises:
        HTTPException: 404 if the story is not found for this user; 400 if it
            has no ``cover_prompt``; 500 if image generation fails.
    """
    story = await get_story_detail(story_id, user_id, db)
    if not story.cover_prompt:
        raise HTTPException(status_code=400, detail="Story has no cover prompt")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()

    # Guard only the provider call: a DB error while saving a successful result
    # must not be reported as "Image generation failed".
    try:
        image_url = await generate_image(story.cover_prompt, db=db)
    except Exception as e:
        # Log first so the failure is recorded even if the commit below fails.
        logger.error("cover_generation_failed", story_id=story_id, error=str(e))
        sync_story_status(
            story,
            image_status=StoryAssetStatus.FAILED,
            last_error=str(e),
        )
        await db.commit()
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") from e

    story.image_url = image_url
    sync_story_status(story, image_status=StoryAssetStatus.READY)
    await db.commit()
    return image_url
async def generate_story_audio(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> bytes:
    """Return narration audio for a story, serving the cache when possible.

    If ``story.audio_path`` points at an existing cache entry, the cached bytes
    are returned and a stale ``audio_status`` is corrected to READY. A dangling
    ``audio_path`` is cleared. Otherwise audio is synthesized with
    ``text_to_speech`` and written to the cache, and the status becomes READY.
    On failure the status becomes FAILED, with ``last_error`` set.

    Raises:
        HTTPException: 404 if the story is not found for this user; 400 if it
            has no text; 500 if synthesis or caching fails.
    """
    story = await get_story_detail(story_id, user_id, db)
    if not story.story_text:
        raise HTTPException(status_code=400, detail="Story has no text")

    if story.audio_path:
        # Check the cache once instead of twice for the same path.
        if audio_cache_exists(story.audio_path):
            if story.audio_status != StoryAssetStatus.READY.value:
                sync_story_status(story, audio_status=StoryAssetStatus.READY)
                await db.commit()
            return read_audio_cache(story.audio_path)

        # The DB references a cache file that no longer exists: forget it.
        logger.warning(
            "story_audio_cache_missing",
            story_id=story_id,
            audio_path=story.audio_path,
        )
        story.audio_path = None
        if story.audio_status == StoryAssetStatus.READY.value:
            sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
        await db.commit()

    # Kept as a function-local import, as in the original.
    # NOTE(review): possibly deliberate (e.g. for patching in tests) — confirm
    # before moving it to the top-level provider_router import.
    from app.services.provider_router import text_to_speech

    sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
    await db.commit()

    try:
        audio_data = await text_to_speech(story.story_text, db=db)
        story.audio_path = write_story_audio_cache(story.id, audio_data)
    except Exception as e:
        # Log first so the failure is recorded even if the commit below fails.
        logger.error("audio_generation_failed", story_id=story_id, error=str(e))
        story.audio_path = None
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.FAILED,
            last_error=str(e),
        )
        await db.commit()
        raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") from e

    sync_story_status(story, audio_status=StoryAssetStatus.READY)
    await db.commit()
    return audio_data
async def get_story_achievements(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> list[AchievementItem]:
    """List the achievements in the story's universe that this story unlocked.

    The universe's ``achievements`` collection is scanned for dict entries whose
    ``source_story_id`` equals ``story_id``. Stories without a universe, and
    universes without achievements, yield an empty list.

    Raises:
        HTTPException: 404 if the story is not found for this user.
    """
    story_query = (
        select(Story)
        .options(joinedload(Story.story_universe))
        .where(Story.id == story_id, Story.user_id == user_id)
    )
    story = (await db.execute(story_query)).scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")

    universe = story.story_universe if story.universe_id else None
    if not universe or not universe.achievements:
        return []

    return [
        AchievementItem(
            type=entry.get("type", "Unknown"),
            description=entry.get("description", ""),
            obtained_at=entry.get("obtained_at"),
        )
        for entry in universe.achievements
        if isinstance(entry, dict) and entry.get("source_story_id") == story_id
    ]