refactor: unify asset completion workflows

This commit is contained in:
2026-04-18 13:20:05 +08:00
parent ae7bd79267
commit f1cbd202ab
4 changed files with 205 additions and 116 deletions

View File

@@ -318,6 +318,148 @@ async def _complete_cover_image_asset(
) from exc
return None, provider_error
def _get_storybook_pages_data(story: Story) -> list[dict]:
"""Return mutable storybook page data from the persisted JSON field."""
return [dict(page) for page in story.pages or [] if isinstance(page, dict)]
async def _complete_storybook_image_assets(
story: Story,
db: AsyncSession,
) -> None:
"""Complete missing cover/page images for a persisted storybook."""
pages_data = _get_storybook_pages_data(story)
has_image_prompt = bool(story.cover_prompt) or any(
page.get("image_prompt") for page in pages_data
)
if not has_image_prompt:
raise HTTPException(status_code=400, detail="Storybook has no image prompts")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
cover_failed = False
failed_pages: list[int] = []
if story.cover_prompt and not story.image_url:
try:
story.image_url = await generate_image(story.cover_prompt, db=db)
except Exception as exc:
cover_failed = True
logger.warning(
"storybook_cover_asset_completion_failed",
story_id=story.id,
error=str(exc),
)
for page in pages_data:
if not page.get("image_prompt") or page.get("image_url"):
continue
try:
page["image_url"] = await generate_image(page["image_prompt"], db=db)
except Exception as exc:
page_number = page.get("page_number")
if isinstance(page_number, int):
failed_pages.append(page_number)
logger.warning(
"storybook_page_asset_completion_failed",
story_id=story.id,
page=page_number,
error=str(exc),
)
story.pages = pages_data
sync_story_status(
story,
image_status=_resolve_storybook_image_status(
generate_images=True,
cover_prompt=story.cover_prompt,
cover_url=story.image_url,
pages_data=pages_data,
),
last_error=_build_storybook_error_message(
cover_failed=cover_failed,
failed_pages=failed_pages,
),
)
await db.commit()
async def _read_cached_audio_asset(story: Story, db: AsyncSession) -> bytes | None:
"""Read cached audio or repair stale audio cache metadata."""
if story.audio_path and audio_cache_exists(story.audio_path):
if story.audio_status != StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.READY)
await db.commit()
return read_audio_cache(story.audio_path)
if story.audio_path and not audio_cache_exists(story.audio_path):
logger.warning(
"story_audio_cache_missing",
story_id=story.id,
audio_path=story.audio_path,
)
story.audio_path = None
if story.audio_status == StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
await db.commit()
return None
async def _complete_audio_asset(
story: Story,
db: AsyncSession,
*,
raise_on_failure: bool = True,
) -> bytes | None:
"""Complete TTS audio generation through one asset workflow."""
if not story.story_text:
raise HTTPException(status_code=400, detail="Story has no text")
cached_audio = await _read_cached_audio_asset(story, db)
if cached_audio is not None:
return cached_audio
from app.services.provider_router import text_to_speech
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
audio_data = await text_to_speech(story.story_text, db=db)
story.audio_path = write_story_audio_cache(story.id, audio_data)
sync_story_status(
story,
audio_status=StoryAssetStatus.READY,
)
await db.commit()
return audio_data
except Exception as exc:
provider_error = str(exc)
story.audio_path = None
sync_story_status(
story,
audio_status=StoryAssetStatus.FAILED,
last_error=provider_error,
)
await db.commit()
logger.error("audio_generation_failed", story_id=story.id, error=provider_error)
if raise_on_failure:
raise HTTPException(
status_code=500,
detail=f"Audio generation failed: {provider_error}",
) from exc
return None
async def validate_profile_and_universe(
@@ -672,74 +814,13 @@ async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
async def _retry_storybook_image_assets(story: Story, db: AsyncSession) -> None:
"""Retry missing storybook cover/page images."""
pages_data = [dict(page) for page in story.pages or [] if isinstance(page, dict)]
has_image_prompt = bool(story.cover_prompt) or any(
page.get("image_prompt") for page in pages_data
)
if not has_image_prompt:
raise HTTPException(status_code=400, detail="Storybook has no image prompts")
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
cover_failed = False
failed_pages: list[int] = []
if story.cover_prompt and not story.image_url:
try:
story.image_url = await generate_image(story.cover_prompt, db=db)
except Exception as exc:
cover_failed = True
logger.warning(
"storybook_cover_asset_retry_failed",
story_id=story.id,
error=str(exc),
)
for page in pages_data:
if not page.get("image_prompt") or page.get("image_url"):
continue
try:
page["image_url"] = await generate_image(page["image_prompt"], db=db)
except Exception as exc:
page_number = page.get("page_number")
if isinstance(page_number, int):
failed_pages.append(page_number)
logger.warning(
"storybook_page_asset_retry_failed",
story_id=story.id,
page=page_number,
error=str(exc),
)
story.pages = pages_data
sync_story_status(
story,
image_status=_resolve_storybook_image_status(
generate_images=True,
cover_prompt=story.cover_prompt,
cover_url=story.image_url,
pages_data=pages_data,
),
last_error=_build_storybook_error_message(
cover_failed=cover_failed,
failed_pages=failed_pages,
),
)
await db.commit()
await _complete_storybook_image_assets(story, db)
async def _retry_audio_asset(story_id: int, user_id: str, db: AsyncSession) -> None:
async def _retry_audio_asset(story: Story, db: AsyncSession) -> None:
"""Retry audio generation while preserving persisted status on provider failure."""
try:
await generate_story_audio(story_id, user_id, db)
except HTTPException as exc:
if exc.status_code >= 500:
logger.warning("audio_asset_retry_failed", story_id=story_id, error=exc.detail)
return
raise
await _complete_audio_asset(story, db, raise_on_failure=False)
async def retry_story_assets(
@@ -760,7 +841,7 @@ async def retry_story_assets(
await _retry_cover_image_asset(story, db)
if "audio" in requested_assets:
await _retry_audio_asset(story_id, user_id, db)
await _retry_audio_asset(story, db)
return await get_story_detail(story_id, user_id, db)
@@ -790,53 +871,14 @@ async def generate_story_audio(
user_id: str,
db: AsyncSession,
) -> bytes:
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
"""Generate audio for a story."""
story = await get_story_detail(story_id, user_id, db)
if not story.story_text:
raise HTTPException(status_code=400, detail="Story has no text")
if story.audio_path and audio_cache_exists(story.audio_path):
if story.audio_status != StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.READY)
await db.commit()
return read_audio_cache(story.audio_path)
if story.audio_path and not audio_cache_exists(story.audio_path):
logger.warning(
"story_audio_cache_missing",
story_id=story_id,
audio_path=story.audio_path,
)
story.audio_path = None
if story.audio_status == StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
await db.commit()
from app.services.provider_router import text_to_speech
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
try:
audio_data = await text_to_speech(story.story_text, db=db)
story.audio_path = write_story_audio_cache(story.id, audio_data)
sync_story_status(
story,
audio_status=StoryAssetStatus.READY,
)
await db.commit()
audio_data = await _complete_audio_asset(story, db, raise_on_failure=True)
if audio_data is not None:
return audio_data
except Exception as e:
story.audio_path = None
sync_story_status(
story,
audio_status=StoryAssetStatus.FAILED,
last_error=str(e),
)
await db.commit()
logger.error("audio_generation_failed", story_id=story_id, error=str(e))
raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")
raise HTTPException(status_code=500, detail="Audio generation failed")
async def get_story_achievements(

View File

@@ -593,6 +593,51 @@ class TestAssetRetry:
assert data["pages"][1]["image_url"] == "https://example.com/retried-page.png"
mock_image.assert_awaited_once()
def test_retry_audio_success(
self,
auth_client: TestClient,
test_story,
mock_tts_provider,
):
response = auth_client.post(
f"/api/stories/{test_story.id}/assets/retry",
json={"assets": ["audio"]},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "completed"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "ready"
assert data["last_error"] is None
mock_tts_provider.assert_awaited_once()
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
assert cached_audio_path.is_file()
def test_retry_audio_failure_updates_status_without_blocking_response(
self,
auth_client: TestClient,
test_story,
):
with patch(
"app.services.provider_router.text_to_speech",
new_callable=AsyncMock,
) as mock_tts:
mock_tts.side_effect = Exception("TTS provider timeout")
response = auth_client.post(
f"/api/stories/{test_story.id}/assets/retry",
json={"assets": ["audio"]},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "failed"
assert "TTS provider timeout" in data["last_error"]
def test_retry_audio_on_storybook_is_rejected(
self,
auth_client: TestClient,