refactor: consolidate generation workflow helpers

This commit is contained in:
2026-04-18 13:03:23 +08:00
parent e201fa3358
commit ae7bd79267
4 changed files with 361 additions and 248 deletions

View File

@@ -19,6 +19,8 @@ from app.schemas.story_schemas import (
StorybookRequest,
StorybookResponse,
)
from app.services.adapters.storybook.primary import Storybook
from app.services.adapters.text.models import StoryOutput
from app.services.audio_storage import (
audio_cache_exists,
read_audio_cache,
@@ -89,6 +91,233 @@ def _resolve_storybook_image_status(
return StoryAssetStatus.READY
return StoryAssetStatus.FAILED
async def _prepare_generation_context(
*,
profile_id: str | None,
universe_id: str | None,
user_id: str,
db: AsyncSession,
) -> tuple[str | None, str | None, str]:
"""Validate ownership and build the shared generation context."""
resolved_profile_id, resolved_universe_id = await validate_profile_and_universe(
profile_id, universe_id, user_id, db
)
memory_context = await build_enhanced_memory_context(
resolved_profile_id,
resolved_universe_id,
db,
)
return resolved_profile_id, resolved_universe_id, memory_context
def _trigger_story_postprocessing(story: Story) -> None:
"""Trigger non-blocking post-processing for a persisted story."""
if story.universe_id:
extract_story_achievements.delay(story.id, story.universe_id)
async def _persist_text_story_result(
*,
result: StoryOutput,
user_id: str,
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> Story:
"""Persist generated text content as the unified story record."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
)
sync_story_status(
story,
image_status=StoryAssetStatus.NOT_REQUESTED,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
db.add(story)
await db.commit()
await db.refresh(story)
_trigger_story_postprocessing(story)
return story
def _storybook_pages_to_data(storybook: Storybook) -> list[dict]:
"""Convert generated storybook pages to the persisted JSON shape."""
return [
{
"page_number": page.page_number,
"text": page.text,
"image_prompt": page.image_prompt,
"image_url": page.image_url,
}
for page in storybook.pages
]
def _storybook_pages_to_response(pages_data: list[dict]) -> list[StorybookPageResponse]:
"""Convert persisted storybook page JSON to API response models."""
return [
StorybookPageResponse(
page_number=page["page_number"],
text=page["text"],
image_prompt=page["image_prompt"],
image_url=page.get("image_url"),
)
for page in pages_data
]
async def _generate_storybook_image_assets(
storybook: Storybook,
db: AsyncSession,
) -> tuple[str | None, bool, list[int]]:
"""Generate storybook cover and page images before persistence."""
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
async def _gen_cover() -> str | None:
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as exc:
cover_failed = True
logger.warning("cover_gen_failed", error=str(exc))
return storybook.cover_url
async def _gen_page(page) -> None:
if not page.image_prompt or page.image_url:
return
try:
page.image_url = await generate_image(page.image_prompt, db=db)
except Exception as exc:
failed_pages.append(page.page_number)
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
results = await asyncio.gather(
_gen_cover(),
*(_gen_page(page) for page in storybook.pages),
return_exceptions=True,
)
cover_result = results[0]
if isinstance(cover_result, str):
final_cover_url = cover_result
logger.info("storybook_parallel_generation_complete")
return final_cover_url, cover_failed, failed_pages
async def _persist_storybook_result(
*,
storybook: Storybook,
user_id: str,
profile_id: str | None,
universe_id: str | None,
final_cover_url: str | None,
generate_images: bool,
cover_failed: bool,
failed_pages: list[int],
db: AsyncSession,
) -> tuple[Story, list[dict]]:
"""Persist generated storybook content as the unified story record."""
pages_data = _storybook_pages_to_data(storybook)
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data,
story_text=None,
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
)
sync_story_status(
story,
image_status=_resolve_storybook_image_status(
generate_images=generate_images,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
pages_data=pages_data,
),
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=_build_storybook_error_message(
cover_failed=cover_failed,
failed_pages=failed_pages,
),
)
db.add(story)
await db.commit()
await db.refresh(story)
_trigger_story_postprocessing(story)
return story, pages_data
async def _complete_cover_image_asset(
story: Story,
db: AsyncSession,
*,
raise_on_failure: bool = False,
last_error_prefix: str | None = None,
log_event: str = "cover_asset_generation_failed",
) -> tuple[str | None, str | None]:
"""Generate or retry a text story cover through one asset workflow."""
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(story, image_status=StoryAssetStatus.READY)
await db.commit()
return image_url, None
except Exception as exc:
provider_error = str(exc)
last_error = (
f"{last_error_prefix}: {provider_error}"
if last_error_prefix
else provider_error
)
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=last_error,
)
await db.commit()
logger.warning(log_event, story_id=story.id, error=provider_error)
if raise_on_failure:
raise HTTPException(
status_code=500,
detail=f"Image generation failed: {provider_error}",
) from exc
return None, provider_error
async def validate_profile_and_universe(
@@ -137,20 +366,18 @@ async def generate_and_save_story(
user_id: str,
db: AsyncSession,
) -> Story:
"""Generate generic story content and save to DB."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
"""Generate generic story content and save to DB."""
profile_id, universe_id, memory_context = await _prepare_generation_context(
profile_id=request.child_profile_id,
universe_id=request.universe_id,
user_id=user_id,
db=db,
)
# 2. Build Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
@@ -160,32 +387,14 @@ async def generate_and_save_story(
status_code=502,
detail="Story generation failed, please try again.",
) from exc
# 4. Save
story = Story(
return await _persist_text_story_result(
result=result,
user_id=user_id,
child_profile_id=profile_id,
profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
db=db,
)
sync_story_status(
story,
image_status=StoryAssetStatus.NOT_REQUESTED,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
db.add(story)
await db.commit()
await db.refresh(story)
# 5. Trigger Async Tasks
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def generate_full_story_service(
@@ -193,36 +402,19 @@ async def generate_full_story_service(
user_id: str,
db: AsyncSession,
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
# 1. Generate text part
# We can reuse logic or call generate_story_content directly if we want finer control
# reusing generate_and_save_story to ensure consistency (it handles validation + saving)
story = await generate_and_save_story(request, user_id, db)
# 2. Generate Image (Parallel/Async step in this flow)
"""Generate story with parallel image generation."""
story = await generate_and_save_story(request, user_id, db)
image_url: str | None = None
errors: dict[str, str | None] = {}
if story.cover_prompt:
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(
story,
image_status=StoryAssetStatus.READY,
)
await db.commit()
except Exception as exc:
errors["image"] = str(exc)
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=str(exc),
)
await db.commit()
logger.warning("image_generation_failed", story_id=story.id, error=str(exc))
image_url, image_error = await _complete_cover_image_asset(
story,
db,
log_event="image_generation_failed",
)
if image_error:
errors["image"] = image_error
return FullStoryResponse(
id=story.id,
@@ -247,133 +439,60 @@ async def generate_storybook_service(
user_id: str,
db: AsyncSession,
) -> StorybookResponse:
"""Generate storybook with parallel image generation for pages."""
# 1. Validate
profile_id, universe_id = await validate_profile_and_universe(
request.child_profile_id, request.universe_id, user_id, db
)
"""Generate storybook with parallel image generation for pages."""
profile_id, universe_id, memory_context = await _prepare_generation_context(
profile_id=request.child_profile_id,
universe_id=request.universe_id,
user_id=user_id,
db=db,
)
logger.info(
"storybook_request",
user_id=user_id,
logger.info(
"storybook_request",
user_id=user_id,
keywords=request.keywords,
page_count=request.page_count,
profile_id=profile_id,
universe_id=universe_id,
)
# 2. Context
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
# 3. Generate Text Structure
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
universe_id=universe_id,
)
try:
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
# 4. Parallel Image Generation
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
if request.generate_images:
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
(
final_cover_url,
cover_failed,
failed_pages,
) = await _generate_storybook_image_assets(storybook, db)
tasks = []
async def _gen_cover():
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
except Exception as exc:
cover_failed = True
logger.warning("cover_gen_failed", error=str(exc))
return storybook.cover_url
tasks.append(_gen_cover())
async def _gen_page(page):
if page.image_prompt and not page.image_url:
try:
page.image_url = await generate_image(page.image_prompt, db=db)
except Exception as exc:
failed_pages.append(page.page_number)
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
for page in storybook.pages:
tasks.append(_gen_page(page))
results = await asyncio.gather(*tasks, return_exceptions=True)
cover_res = results[0]
if isinstance(cover_res, str):
final_cover_url = cover_res
logger.info("storybook_parallel_generation_complete")
# 5. Save to DB
pages_data = [
{
"page_number": p.page_number,
"text": p.text,
"image_prompt": p.image_prompt,
"image_url": p.image_url,
}
for p in storybook.pages
]
story = Story(
story, pages_data = await _persist_storybook_result(
storybook=storybook,
user_id=user_id,
child_profile_id=profile_id,
profile_id=profile_id,
universe_id=universe_id,
title=storybook.title,
mode="storybook",
pages=pages_data,
story_text=None,
cover_prompt=storybook.cover_prompt,
image_url=final_cover_url,
final_cover_url=final_cover_url,
generate_images=request.generate_images,
cover_failed=cover_failed,
failed_pages=failed_pages,
db=db,
)
sync_story_status(
story,
image_status=_resolve_storybook_image_status(
generate_images=request.generate_images,
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
pages_data=pages_data,
),
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=_build_storybook_error_message(
cover_failed=cover_failed,
failed_pages=failed_pages,
),
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
# 6. Build Response
response_pages = [
StorybookPageResponse(
page_number=p["page_number"],
text=p["text"],
image_prompt=p["image_prompt"],
image_url=p.get("image_url"),
)
for p in pages_data
]
response_pages = _storybook_pages_to_response(pages_data)
return StorybookResponse(
id=story.id,
title=storybook.title,
@@ -523,59 +642,31 @@ async def delete_story(
async def create_story_from_result(
result, # StoryOutput
result: StoryOutput,
user_id: str,
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
story = Story(
user_id=user_id,
child_profile_id=profile_id,
universe_id=universe_id,
title=result.title,
story_text=result.story_text,
cover_prompt=result.cover_prompt_suggestion,
mode=result.mode,
db: AsyncSession,
) -> Story:
"""Save a generated story to DB (helper for stream endpoint)."""
return await _persist_text_story_result(
result=result,
user_id=user_id,
profile_id=profile_id,
universe_id=universe_id,
db=db,
)
sync_story_status(
story,
image_status=StoryAssetStatus.NOT_REQUESTED,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
db.add(story)
await db.commit()
await db.refresh(story)
if universe_id:
extract_story_achievements.delay(story.id, universe_id)
return story
async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
"""Retry cover generation for a text story."""
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
story.image_url = await generate_image(story.cover_prompt, db=db)
sync_story_status(story, image_status=StoryAssetStatus.READY)
except Exception as exc:
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=f"封面生成失败: {exc}",
)
logger.error("cover_asset_retry_failed", story_id=story.id, error=str(exc))
await db.commit()
await _complete_cover_image_asset(
story,
db,
last_error_prefix="封面生成失败",
log_event="cover_asset_retry_failed",
)
async def _retry_storybook_image_assets(story: Story, db: AsyncSession) -> None:
@@ -679,33 +770,19 @@ async def generate_story_cover(
user_id: str,
db: AsyncSession,
) -> str:
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
"""Generate cover image for an existing story."""
story = await get_story_detail(story_id, user_id, db)
if not story.cover_prompt:
raise HTTPException(status_code=400, detail="Story has no cover prompt")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = image_url
sync_story_status(
story,
image_status=StoryAssetStatus.READY,
)
await db.commit()
image_url, _ = await _complete_cover_image_asset(
story,
db,
raise_on_failure=True,
log_event="cover_generation_failed",
)
if image_url is not None:
return image_url
except Exception as e:
sync_story_status(
story,
image_status=StoryAssetStatus.FAILED,
last_error=str(e),
)
await db.commit()
logger.error("cover_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
raise HTTPException(status_code=500, detail="Image generation failed")
async def generate_story_audio(

View File

@@ -428,6 +428,33 @@ class TestUnifiedGenerations:
assert data["generation_status"] == "narrative_ready"
assert data["image_status"] == "not_requested"
def test_create_story_generation_image_failure(
self,
auth_client: TestClient,
mock_text_provider,
):
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error")
response = auth_client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["image_url"] is None
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "failed"
assert data["audio_status"] == "not_requested"
assert "Image API error" in data["errors"]["image"]
assert "Image API error" in data["last_error"]
def test_create_storybook_generation_success(self, auth_client: TestClient):
with patch(
"app.services.story_service.generate_storybook",