Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
586 lines
17 KiB
Python
586 lines
17 KiB
Python
"""Story business logic service."""
|
||
|
||
import asyncio
|
||
|
||
from fastapi import HTTPException
|
||
from sqlalchemy import desc, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import joinedload
|
||
|
||
from app.core.logging import get_logger
|
||
from app.db.models import ChildProfile, Story, StoryUniverse
|
||
from app.schemas.story_schemas import (
|
||
GenerateRequest,
|
||
StorybookRequest,
|
||
FullStoryResponse,
|
||
StorybookResponse,
|
||
StorybookPageResponse,
|
||
AchievementItem,
|
||
)
|
||
from app.services.audio_storage import (
|
||
audio_cache_exists,
|
||
read_audio_cache,
|
||
write_story_audio_cache,
|
||
)
|
||
from app.services.memory_service import build_enhanced_memory_context
|
||
from app.services.provider_router import (
|
||
generate_story_content,
|
||
generate_image,
|
||
generate_storybook,
|
||
)
|
||
from app.services.story_status import (
|
||
StoryAssetStatus,
|
||
sync_story_status,
|
||
)
|
||
from app.tasks.achievements import extract_story_achievements
|
||
|
||
# Module-level logger. Calls in this module pass an event name followed by
# keyword context, e.g. logger.info("storybook_request", user_id=...).
logger = get_logger(__name__)
|
||
|
||
|
||
def _build_storybook_error_message(
|
||
*,
|
||
cover_failed: bool,
|
||
failed_pages: list[int],
|
||
) -> str | None:
|
||
"""Summarize storybook image generation errors for the latest attempt."""
|
||
|
||
parts: list[str] = []
|
||
if cover_failed:
|
||
parts.append("封面生成失败")
|
||
if failed_pages:
|
||
pages = "、".join(str(page) for page in sorted(failed_pages))
|
||
parts.append(f"第 {pages} 页插图生成失败")
|
||
return ";".join(parts) if parts else None
|
||
|
||
|
||
def _resolve_storybook_image_status(
    *,
    generate_images: bool,
    cover_prompt: str | None,
    cover_url: str | None,
    pages_data: list[dict],
) -> StoryAssetStatus:
    """Derive the image status to persist for a freshly generated storybook.

    An image "slot" exists for the cover and for every page that has either
    an ``image_prompt`` or an ``image_url``. The result is:

    * ``NOT_REQUESTED`` when images were not requested or no slot exists;
    * ``READY`` when every slot has a URL;
    * ``FAILED`` when at least one slot is still missing its URL.
    """
    if not generate_images:
        return StoryAssetStatus.NOT_REQUESTED

    # One (has_slot, has_url) pair per potential image: cover first, then pages.
    candidates = [(bool(cover_prompt or cover_url), bool(cover_url))]
    for page in pages_data:
        url = page.get("image_url")
        candidates.append((bool(page.get("image_prompt") or url), bool(url)))

    slots = [has_url for has_slot, has_url in candidates if has_slot]
    if not slots:
        return StoryAssetStatus.NOT_REQUESTED
    return StoryAssetStatus.READY if all(slots) else StoryAssetStatus.FAILED
|
||
|
||
|
||
async def validate_profile_and_universe(
    profile_id: str | None,
    universe_id: str | None,
    user_id: str,
    db: AsyncSession,
) -> tuple[str | None, str | None]:
    """Check that the requested child profile / story universe belong to the user.

    Args:
        profile_id: Optional child profile id supplied by the client.
        universe_id: Optional story universe id supplied by the client.
        user_id: Id of the authenticated user; both lookups are scoped to it.
        db: Async database session used for the lookups.

    Returns:
        ``(profile_id, universe_id)``. When only a universe is given, the
        profile id is filled in from that universe's ``child_profile_id``.
        ``(None, None)`` when neither id is given.

    Raises:
        HTTPException: 404 if the profile or universe is not found for this
            user; 400 if both are given but the universe belongs to a
            different profile.
    """
    if not (profile_id or universe_id):
        return None, None

    if profile_id:
        profile_stmt = select(ChildProfile).where(
            ChildProfile.id == profile_id,
            ChildProfile.user_id == user_id,
        )
        profile = (await db.execute(profile_stmt)).scalar_one_or_none()
        if not profile:
            raise HTTPException(status_code=404, detail="孩子档案不存在")

    if not universe_id:
        return profile_id, universe_id

    # Ownership of a universe is established through its child profile.
    universe_stmt = (
        select(StoryUniverse)
        .join(ChildProfile, StoryUniverse.child_profile_id == ChildProfile.id)
        .where(
            StoryUniverse.id == universe_id,
            ChildProfile.user_id == user_id,
        )
    )
    universe = (await db.execute(universe_stmt)).scalar_one_or_none()
    if not universe:
        raise HTTPException(status_code=404, detail="故事宇宙不存在")

    if not profile_id:
        # No explicit profile: infer it from the universe.
        return universe.child_profile_id, universe_id
    if universe.child_profile_id != profile_id:
        raise HTTPException(status_code=400, detail="故事宇宙与孩子档案不匹配")
    return profile_id, universe_id
|
||
|
||
|
||
async def generate_and_save_story(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Generate a text story for the user and persist it.

    Steps: validate profile/universe ownership, build the memory context,
    call the story provider, then save via ``create_story_from_result``.
    That helper also queues achievement extraction when a universe is set.

    Args:
        request: Generation request (input type, data, education theme and
            optional child profile / universe ids).
        user_id: Id of the authenticated user who will own the story.
        db: Async database session.

    Returns:
        The saved ``Story`` row, refreshed after commit.

    Raises:
        HTTPException: 404/400 from ``validate_profile_and_universe``; 502 if
            the provider call fails.
    """
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )

    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    try:
        result = await generate_story_content(
            input_type=request.type,
            data=request.data,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as exc:
        # The client only receives a generic message, so record the real cause.
        logger.error("story_generation_failed", user_id=user_id, error=str(exc))
        raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc

    # Persisting is identical to the streaming path. Reuse that helper so the
    # initial asset statuses and the achievement trigger cannot drift apart.
    return await create_story_from_result(result, user_id, profile_id, universe_id, db)
|
||
|
||
|
||
async def generate_full_story_service(
    request: GenerateRequest,
    user_id: str,
    db: AsyncSession,
) -> FullStoryResponse:
    """Generate a story and, when the model suggested one, its cover image.

    The text story is generated and saved first via ``generate_and_save_story``.
    If it has a cover prompt, the cover is then generated. An image failure
    does not fail the request: it is reported in ``errors["image"]`` and
    persisted as ``image_status=FAILED`` with ``last_error``.

    Args:
        request: Generation request forwarded to ``generate_and_save_story``.
        user_id: Id of the authenticated user who owns the story.
        db: Async database session.

    Returns:
        The story text, cover URL (if generated), statuses and per-asset errors.

    Raises:
        HTTPException: propagated from ``generate_and_save_story``.
    """
    story = await generate_and_save_story(request, user_id, db)

    image_url: str | None = None
    errors: dict[str, str | None] = {}

    if story.cover_prompt:
        sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
        await db.commit()

        # Guard only the provider call. A failure while committing the success
        # state must surface as-is, not be recorded (and re-committed on a
        # failed session) as an image-generation failure.
        try:
            image_url = await generate_image(story.cover_prompt, db=db)
        except Exception as exc:
            errors["image"] = str(exc)
            logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
            sync_story_status(
                story,
                image_status=StoryAssetStatus.FAILED,
                last_error=str(exc),
            )
            await db.commit()
        else:
            story.image_url = image_url
            sync_story_status(story, image_status=StoryAssetStatus.READY)
            await db.commit()

    return FullStoryResponse(
        id=story.id,
        title=story.title,
        story_text=story.story_text,
        cover_prompt=story.cover_prompt,
        image_url=image_url,
        audio_ready=False,
        mode=story.mode,
        errors=errors,
        child_profile_id=story.child_profile_id,
        universe_id=story.universe_id,
        generation_status=story.generation_status,
        image_status=story.image_status,
        audio_status=story.audio_status,
        last_error=story.last_error,
    )
|
||
|
||
|
||
async def _generate_storybook_images(
    storybook,
    db: AsyncSession,
) -> tuple[str | None, bool, list[int]]:
    """Concurrently generate the missing cover and page images of *storybook*.

    Only images that have a prompt but no URL yet are generated. Page URLs
    are written back onto ``storybook.pages`` in place. Failures are logged
    and reported in the return value instead of being raised.

    Args:
        storybook: Output of ``generate_storybook``. Its ``cover_prompt`` and
            ``cover_url`` are read, plus each page's ``page_number``,
            ``image_prompt`` and ``image_url``.
        db: Session forwarded to ``generate_image``.

    Returns:
        ``(cover_url, cover_failed, failed_pages)``:
        - ``cover_url``: the generated or pre-existing cover URL, possibly ``None``.
        - ``cover_failed``: whether cover generation failed.
        - ``failed_pages``: page numbers whose image generation failed.
    """
    cover_failed = False
    failed_pages: list[int] = []

    async def _gen_cover():
        nonlocal cover_failed
        if storybook.cover_prompt and not storybook.cover_url:
            try:
                return await generate_image(storybook.cover_prompt, db=db)
            except Exception as exc:
                cover_failed = True
                logger.warning("cover_gen_failed", error=str(exc))
        return storybook.cover_url

    async def _gen_page(page):
        if page.image_prompt and not page.image_url:
            try:
                page.image_url = await generate_image(page.image_prompt, db=db)
            except Exception as exc:
                failed_pages.append(page.page_number)
                logger.warning("page_gen_failed", page=page.page_number, error=str(exc))

    # NOTE(review): every coroutine shares one AsyncSession through ``db=db``.
    # SQLAlchemy does not support concurrent operations on a single
    # AsyncSession. If ``generate_image`` touches the database, these calls
    # can collide. Confirm, and give each task its own session if so.
    cover_res, *page_results = await asyncio.gather(
        _gen_cover(),
        *(_gen_page(page) for page in storybook.pages),
        return_exceptions=True,
    )

    final_cover_url = storybook.cover_url
    # return_exceptions=True hands back anything that escaped a coroutine as a
    # value. Record it as a failure instead of dropping it silently.
    if isinstance(cover_res, BaseException):
        cover_failed = True
        logger.warning("cover_gen_failed", error=str(cover_res))
    elif isinstance(cover_res, str):
        final_cover_url = cover_res

    for page, outcome in zip(storybook.pages, page_results):
        if isinstance(outcome, BaseException) and page.page_number not in failed_pages:
            failed_pages.append(page.page_number)
            logger.warning("page_gen_failed", page=page.page_number, error=str(outcome))

    return final_cover_url, cover_failed, failed_pages


async def generate_storybook_service(
    request: StorybookRequest,
    user_id: str,
    db: AsyncSession,
) -> StorybookResponse:
    """Generate a picture storybook, optionally with illustrations, and save it.

    Steps:
    1. Validate profile/universe ownership.
    2. Build the memory context.
    3. Generate the page structure.
    4. Optionally generate the cover and page images concurrently.
    5. Persist the ``Story`` (mode ``"storybook"``).
    6. Queue achievement extraction when a universe is set.
    7. Build the response.

    Image failures do not fail the request. They are summarized in
    ``last_error`` and reflected in ``image_status``.

    Args:
        request: Keywords, page count, education theme, ``generate_images``
            flag and optional child profile / universe ids.
        user_id: Owner of the new story.
        db: Async database session.

    Returns:
        The storybook with per-page image URLs and the persisted statuses.

    Raises:
        HTTPException: 404/400 from ``validate_profile_and_universe``; 500 if
            storybook text generation fails.
    """
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id, request.universe_id, user_id, db
    )

    logger.info(
        "storybook_request",
        user_id=user_id,
        keywords=request.keywords,
        page_count=request.page_count,
        profile_id=profile_id,
        universe_id=universe_id,
    )

    memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)

    try:
        storybook = await generate_storybook(
            keywords=request.keywords,
            page_count=request.page_count,
            education_theme=request.education_theme,
            memory_context=memory_context,
            db=db,
        )
    except Exception as e:
        logger.error("storybook_generation_failed", error=str(e))
        # Chain the cause so tracebacks keep the provider error.
        raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}") from e

    final_cover_url = storybook.cover_url
    cover_failed = False
    failed_pages: list[int] = []

    if request.generate_images:
        logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
        final_cover_url, cover_failed, failed_pages = await _generate_storybook_images(storybook, db)
        logger.info("storybook_parallel_generation_complete")

    pages_data = [
        {
            "page_number": p.page_number,
            "text": p.text,
            "image_prompt": p.image_prompt,
            "image_url": p.image_url,
        }
        for p in storybook.pages
    ]

    story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=storybook.title,
        mode="storybook",
        pages=pages_data,
        story_text=None,
        cover_prompt=storybook.cover_prompt,
        image_url=final_cover_url,
    )
    sync_story_status(
        story,
        image_status=_resolve_storybook_image_status(
            generate_images=request.generate_images,
            cover_prompt=storybook.cover_prompt,
            cover_url=final_cover_url,
            pages_data=pages_data,
        ),
        audio_status=StoryAssetStatus.NOT_REQUESTED,
        last_error=_build_storybook_error_message(
            cover_failed=cover_failed,
            failed_pages=failed_pages,
        ),
    )
    db.add(story)
    await db.commit()
    await db.refresh(story)

    if universe_id:
        extract_story_achievements.delay(story.id, universe_id)

    response_pages = [
        StorybookPageResponse(
            page_number=p["page_number"],
            text=p["text"],
            image_prompt=p["image_prompt"],
            image_url=p.get("image_url"),
        )
        for p in pages_data
    ]

    return StorybookResponse(
        id=story.id,
        title=storybook.title,
        main_character=storybook.main_character,
        art_style=storybook.art_style,
        pages=response_pages,
        cover_prompt=storybook.cover_prompt,
        cover_url=final_cover_url,
        generation_status=story.generation_status,
        image_status=story.image_status,
        audio_status=story.audio_status,
        last_error=story.last_error,
    )
|
||
|
||
|
||
# ==================== Missing Endpoints Logic (for Issue #5) ====================
|
||
|
||
async def list_stories(
    user_id: str,
    limit: int,
    offset: int,
    db: AsyncSession,
) -> list[Story]:
    """Return one page of the user's stories, newest first.

    Args:
        user_id: Owner whose stories are listed.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (pagination).
        db: Async database session.
    """
    newest_first = (
        select(Story)
        .where(Story.user_id == user_id)
        .order_by(desc(Story.created_at))
        .offset(offset)
        .limit(limit)
    )
    rows = await db.execute(newest_first)
    return rows.scalars().all()
|
||
|
||
|
||
async def get_story_detail(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> Story:
    """Load a single story owned by *user_id*.

    The lookup filters on both id and owner, so another user's story is
    reported exactly like a missing one.

    Raises:
        HTTPException: 404 when no matching story exists for this user.
    """
    lookup = select(Story).where(Story.id == story_id, Story.user_id == user_id)
    story = (await db.execute(lookup)).scalar_one_or_none()
    if story:
        return story
    raise HTTPException(status_code=404, detail="Story not found")
|
||
|
||
|
||
async def delete_story(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> None:
    """Delete the user's story and commit.

    Raises:
        HTTPException: 404 (from ``get_story_detail``) if the story is not
            found for this user.
    """
    # NOTE(review): any cached audio referenced by ``audio_path`` is not
    # removed here. Confirm whether that is intended.
    target = await get_story_detail(story_id, user_id, db)
    await db.delete(target)
    await db.commit()
|
||
|
||
|
||
async def create_story_from_result(
    result,  # StoryOutput
    user_id: str,
    profile_id: str | None,
    universe_id: str | None,
    db: AsyncSession,
) -> Story:
    """Persist an already-generated story (helper for the stream endpoint).

    Image and audio start as ``NOT_REQUESTED`` with no error recorded. When
    the story belongs to a universe, achievement extraction is queued with
    ``.delay`` after the commit.

    Args:
        result: Provider output. ``title``, ``story_text``,
            ``cover_prompt_suggestion`` and ``mode`` are read from it.
        user_id: Owner of the new story.
        profile_id: Optional child profile id (assumed already validated).
        universe_id: Optional story universe id (assumed already validated).
        db: Async database session.

    Returns:
        The saved ``Story``, refreshed after commit.
    """
    new_story = Story(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        title=result.title,
        story_text=result.story_text,
        cover_prompt=result.cover_prompt_suggestion,
        mode=result.mode,
    )
    sync_story_status(
        new_story,
        image_status=StoryAssetStatus.NOT_REQUESTED,
        audio_status=StoryAssetStatus.NOT_REQUESTED,
        last_error=None,
    )

    db.add(new_story)
    await db.commit()
    # Refresh so server-generated columns (e.g. the id) are loaded.
    await db.refresh(new_story)

    if universe_id:
        extract_story_achievements.delay(new_story.id, universe_id)
    return new_story
|
||
|
||
|
||
async def generate_story_cover(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> str:
    """Generate (or regenerate) the cover image of an existing story.

    ``image_status`` is committed as GENERATING before the provider call, and
    afterwards as either:
    - READY, with ``image_url`` set; or
    - FAILED, with ``last_error`` set.

    Returns:
        The new cover image URL.

    Raises:
        HTTPException: 404 if the story is not found for the user; 400 if it
            has no cover prompt; 500 if image generation fails.
    """
    story = await get_story_detail(story_id, user_id, db)

    if not story.cover_prompt:
        raise HTTPException(status_code=400, detail="Story has no cover prompt")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()

    # Guard only the provider call. A failure while committing the READY state
    # must surface as-is rather than be recorded as an image failure.
    try:
        image_url = await generate_image(story.cover_prompt, db=db)
    except Exception as e:
        logger.error("cover_generation_failed", story_id=story_id, error=str(e))
        sync_story_status(
            story,
            image_status=StoryAssetStatus.FAILED,
            last_error=str(e),
        )
        await db.commit()
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") from e

    story.image_url = image_url
    sync_story_status(story, image_status=StoryAssetStatus.READY)
    await db.commit()
    return image_url
|
||
|
||
|
||
async def generate_story_audio(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> bytes:
    """Return narration audio for a story, generating and caching it if needed.

    Behavior:
    - A cached file at ``story.audio_path`` is served when it still exists.
    - A stale path (file gone) is cleared first.
    - Otherwise text-to-speech runs and the result is written to the audio
      cache. ``audio_status`` moves GENERATING -> READY, or GENERATING ->
      FAILED with ``last_error`` set.

    Returns:
        The audio bytes (from cache or freshly generated).

    Raises:
        HTTPException: 404 if the story is not found for the user; 400 if it
            has no ``story_text``; 500 if text-to-speech or the cache write
            fails.
    """
    story = await get_story_detail(story_id, user_id, db)

    if not story.story_text:
        raise HTTPException(status_code=400, detail="Story has no text")

    if story.audio_path:
        if audio_cache_exists(story.audio_path):
            if story.audio_status != StoryAssetStatus.READY.value:
                # The cached file exists, so repair a status that drifted.
                sync_story_status(story, audio_status=StoryAssetStatus.READY)
                await db.commit()
            return read_audio_cache(story.audio_path)

        # Stale cache reference: forget it and fall through to regeneration.
        logger.warning(
            "story_audio_cache_missing",
            story_id=story_id,
            audio_path=story.audio_path,
        )
        story.audio_path = None
        if story.audio_status == StoryAssetStatus.READY.value:
            sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
        await db.commit()

    # NOTE(review): provider_router is already imported at module level. This
    # call-time import may exist for test patching; confirm before hoisting.
    from app.services.provider_router import text_to_speech

    sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
    await db.commit()

    # Guard generation and the cache write. A failure committing the READY
    # state is not an audio failure and propagates unchanged.
    try:
        audio_data = await text_to_speech(story.story_text, db=db)
        audio_path = write_story_audio_cache(story.id, audio_data)
    except Exception as e:
        logger.error("audio_generation_failed", story_id=story_id, error=str(e))
        story.audio_path = None
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.FAILED,
            last_error=str(e),
        )
        await db.commit()
        raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}") from e

    story.audio_path = audio_path
    sync_story_status(story, audio_status=StoryAssetStatus.READY)
    await db.commit()
    return audio_data
|
||
|
||
|
||
async def get_story_achievements(
    story_id: int,
    user_id: str,
    db: AsyncSession,
) -> list[AchievementItem]:
    """List the achievements unlocked by a specific story.

    Achievements are stored on the story's universe (``achievements``). Only
    dict entries whose ``source_story_id`` equals *story_id* are returned.
    Stories without a universe, or universes without achievements, yield an
    empty list.

    Raises:
        HTTPException: 404 if the story is not found for this user.
    """
    stmt = (
        select(Story)
        # Load the universe eagerly in the same query.
        .options(joinedload(Story.story_universe))
        .where(Story.id == story_id, Story.user_id == user_id)
    )
    story = (await db.execute(stmt)).scalar_one_or_none()
    if not story:
        raise HTTPException(status_code=404, detail="Story not found")

    universe = story.story_universe if story.universe_id else None
    if not universe or not universe.achievements:
        return []

    return [
        AchievementItem(
            type=entry.get("type", "Unknown"),
            description=entry.get("description", ""),
            obtained_at=entry.get("obtained_at"),
        )
        for entry in universe.achievements
        if isinstance(entry, dict) and entry.get("source_story_id") == story_id
    ]
|
||
|