469 lines
15 KiB
Python
469 lines
15 KiB
Python
"""Artifact completion workflows for the generation harness runtime."""
|
||
|
||
from collections.abc import Awaitable, Callable
|
||
|
||
from fastapi import HTTPException
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.logging import get_logger
|
||
from app.db.models import Story
|
||
from app.services.harness.artifacts import AssetCompletionResult, asset_result_metadata
|
||
from app.services.harness.control import ExecutionControl
|
||
from app.services.harness.trace import TraceRecorder
|
||
from app.services.story_status import StoryAssetStatus, sync_story_status
|
||
|
||
# Module-level logger shared by every workflow in this file.
logger = get_logger(__name__)

# Signatures of the provider/cache callables injected into the workflows below.
# Async image provider: called with a prompt plus db/user_id/generation_job/story_id
# keyword arguments and awaited for the value stored as the image URL.
ImageGenerator = Callable[..., Awaitable[str]]
# Async TTS provider: called with the story text plus the same keyword arguments
# and awaited for the audio bytes.
TTSGenerator = Callable[..., Awaitable[bytes]]
# Receives a stored audio path and reports whether the cached audio is present.
AudioCacheExists = Callable[[str], bool]
# Receives a stored audio path and returns the cached audio bytes.
AudioCacheReader = Callable[[str], bytes]
# Receives (story id, audio bytes) and returns the path saved as story.audio_path.
AudioCacheWriter = Callable[[int, bytes], str]
|
||
|
||
|
||
async def complete_cover_image_asset(
    story: Story,
    db: AsyncSession,
    *,
    generate_image_func: ImageGenerator,
    raise_on_failure: bool = False,
    last_error_prefix: str | None = None,
    log_event: str = "cover_asset_generation_failed",
    job=None,
) -> AssetCompletionResult:
    """Generate or retry a text story cover through one asset workflow.

    The story is marked GENERATING, the image provider is awaited, and the
    story is then marked READY (with ``image_url`` set) or FAILED (with
    ``last_error`` set). Each phase is recorded through ``TraceRecorder``.

    Args:
        story: Story whose ``cover_prompt`` is sent to the provider.
        db: Session used for commits, cancellation checks and trace records.
        generate_image_func: Awaitable provider called with the cover prompt
            and ``db``/``user_id``/``generation_job``/``story_id`` keywords.
        raise_on_failure: When true, a generation failure is re-raised as an
            ``HTTPException`` (500) after the failure has been persisted. The
            flag is also copied to ``blocks_main_result`` on the result.
        last_error_prefix: Optional text prepended (``"prefix: error"``) to
            the error stored on the story.
        log_event: Event name used for the warning log on failure.
        job: Optional generation job forwarded to the cancellation check,
            the trace recorder and the provider.

    Returns:
        An ``AssetCompletionResult`` for the ``cover_image`` asset, READY with
        the image URL as ``value`` or FAILED with the provider error.

    Raises:
        HTTPException: 400 when the story has no cover prompt; 500 when
            generation fails and ``raise_on_failure`` is true.
    """
    if not story.cover_prompt:
        raise HTTPException(status_code=400, detail="Story has no cover prompt")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="cover_image_started",
        status="running",
        message="Cover image generation started.",
        metadata={"asset": "image", "cover_prompt_present": True},
    )

    # Only generation and persistence belong in the try block. The success
    # trace is written afterwards so that a trace-recording error cannot turn
    # an already committed READY cover into a FAILED one.
    try:
        # NOTE(review): if stop_if_cancel_requested raises an Exception
        # subclass, it is handled below as a provider failure - confirm intent.
        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        image_url = await generate_image_func(
            story.cover_prompt,
            db=db,
            user_id=story.user_id,
            generation_job=job,
            story_id=story.id,
        )
        story.image_url = image_url
        sync_story_status(story, image_status=StoryAssetStatus.READY)
        await db.commit()
    except Exception as exc:
        provider_error = str(exc)
        last_error = (
            f"{last_error_prefix}: {provider_error}"
            if last_error_prefix
            else provider_error
        )
        sync_story_status(
            story,
            image_status=StoryAssetStatus.FAILED,
            last_error=last_error,
        )
        await db.commit()
        logger.warning(log_event, story_id=story.id, error=provider_error)

        result = AssetCompletionResult(
            asset="cover_image",
            status=StoryAssetStatus.FAILED,
            error=provider_error,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="cover_image_failed",
            status="failed",
            message="Cover image generation failed.",
            metadata=asset_result_metadata(result),
        )
        if raise_on_failure:
            raise HTTPException(
                status_code=500,
                detail=f"Image generation failed: {provider_error}",
            ) from exc

        return result

    result = AssetCompletionResult(
        asset="cover_image",
        status=StoryAssetStatus.READY,
        value=image_url,
        blocks_main_result=raise_on_failure,
    )
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="cover_image_succeeded",
        status="succeeded",
        message="Cover image was generated.",
        metadata=asset_result_metadata(result),
    )
    return result
|
||
|
||
|
||
async def read_cached_audio_asset(
    story: Story,
    db: AsyncSession,
    *,
    audio_cache_exists_func: AudioCacheExists,
    read_audio_cache_func: AudioCacheReader,
) -> bytes | None:
    """Read cached audio or repair stale audio cache metadata.

    Args:
        story: Story whose ``audio_path`` points at the cached audio.
        db: Session used to commit any status/path repair.
        audio_cache_exists_func: Reports whether a cached path is present.
        read_audio_cache_func: Returns the bytes stored at a cached path.

    Returns:
        The cached audio bytes, or ``None`` when the story has no audio path
        or the path no longer exists in the cache.
    """
    # Snapshot the path once: it is needed after a commit and after the
    # attribute is cleared on the stale branch.
    audio_path = story.audio_path
    if not audio_path:
        return None

    # Probe the cache a single time so the hit and miss branches cannot
    # disagree with each other.
    if audio_cache_exists_func(audio_path):
        if story.audio_status != StoryAssetStatus.READY.value:
            # The file exists but the status lags behind; bring it in line.
            sync_story_status(story, audio_status=StoryAssetStatus.READY)
            await db.commit()
        return read_audio_cache_func(audio_path)

    logger.warning(
        "story_audio_cache_missing",
        story_id=story.id,
        audio_path=audio_path,
    )
    # The recorded path is stale: drop it and downgrade a READY status.
    story.audio_path = None
    if story.audio_status == StoryAssetStatus.READY.value:
        sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
    await db.commit()

    return None
|
||
|
||
|
||
async def complete_audio_asset(
    story: Story,
    db: AsyncSession,
    *,
    text_to_speech_func: TTSGenerator,
    audio_cache_exists_func: AudioCacheExists,
    read_audio_cache_func: AudioCacheReader,
    write_story_audio_cache_func: AudioCacheWriter,
    raise_on_failure: bool = True,
    job=None,
) -> AssetCompletionResult:
    """Complete TTS audio generation through one asset workflow.

    Cached audio is reused when available. Otherwise the story is marked
    GENERATING, the TTS provider is awaited, the audio is written to the
    cache and the story is marked READY; on failure the story is marked
    FAILED with ``audio_path`` cleared. Each phase is recorded through
    ``TraceRecorder``.

    Args:
        story: Story whose ``story_text`` is synthesized.
        db: Session used for commits, cancellation checks and trace records.
        text_to_speech_func: Awaitable provider called with the story text
            and ``db``/``user_id``/``generation_job``/``story_id`` keywords.
        audio_cache_exists_func: Reports whether a cached path is present.
        read_audio_cache_func: Returns the bytes stored at a cached path.
        write_story_audio_cache_func: Stores audio bytes for a story id and
            returns the path saved as ``story.audio_path``.
        raise_on_failure: When true, a generation failure is re-raised as an
            ``HTTPException`` (500) after the failure has been persisted. The
            flag is also copied to ``blocks_main_result`` on the result.
        job: Optional generation job forwarded to the cancellation check,
            the trace recorder and the provider.

    Returns:
        An ``AssetCompletionResult`` for the ``audio`` asset, READY with the
        audio bytes as ``value`` or FAILED with the provider error.

    Raises:
        HTTPException: 400 when the story has no text; 500 when generation
            fails and ``raise_on_failure`` is true.
    """
    if not story.story_text:
        raise HTTPException(status_code=400, detail="Story has no text")

    cached_audio = await read_cached_audio_asset(
        story,
        db,
        audio_cache_exists_func=audio_cache_exists_func,
        read_audio_cache_func=read_audio_cache_func,
    )
    if cached_audio is not None:
        result = AssetCompletionResult(
            asset="audio",
            status=StoryAssetStatus.READY,
            value=cached_audio,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="audio_cache_hit",
            status="succeeded",
            message="Cached story audio was reused.",
            metadata=asset_result_metadata(result),
        )
        return result

    sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="audio_started",
        status="running",
        message="Story audio generation started.",
        metadata={"asset": "audio"},
    )

    # Only generation, caching and persistence belong in the try block. The
    # success trace is written afterwards so that a trace-recording error
    # cannot discard a freshly cached audio path or mark READY audio FAILED.
    try:
        # NOTE(review): if stop_if_cancel_requested raises an Exception
        # subclass, it is handled below as a provider failure - confirm intent.
        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        audio_data = await text_to_speech_func(
            story.story_text,
            db=db,
            user_id=story.user_id,
            generation_job=job,
            story_id=story.id,
        )
        story.audio_path = write_story_audio_cache_func(story.id, audio_data)
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.READY,
        )
        await db.commit()
    except Exception as exc:
        provider_error = str(exc)
        story.audio_path = None
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.FAILED,
            last_error=provider_error,
        )
        await db.commit()
        logger.error("audio_generation_failed", story_id=story.id, error=provider_error)

        result = AssetCompletionResult(
            asset="audio",
            status=StoryAssetStatus.FAILED,
            error=provider_error,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="audio_failed",
            status="failed",
            message="Story audio generation failed.",
            metadata=asset_result_metadata(result),
        )
        if raise_on_failure:
            raise HTTPException(
                status_code=500,
                detail=f"Audio generation failed: {provider_error}",
            ) from exc

        return result

    result = AssetCompletionResult(
        asset="audio",
        status=StoryAssetStatus.READY,
        value=audio_data,
        blocks_main_result=raise_on_failure,
    )
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="audio_succeeded",
        status="succeeded",
        message="Story audio was generated and cached.",
        metadata=asset_result_metadata(result),
    )
    return result
|
||
|
||
|
||
def get_storybook_pages_data(story: Story) -> list[dict]:
    """Return mutable storybook page data from the persisted JSON field.

    Each dict entry of ``story.pages`` is shallow-copied so callers can edit
    the result without touching the stored objects; entries that are not
    dicts are dropped, and a missing/empty field yields an empty list.
    """
    stored_pages = story.pages or []
    page_copies: list[dict] = []
    for entry in stored_pages:
        if not isinstance(entry, dict):
            continue
        page_copies.append(dict(entry))
    return page_copies
|
||
|
||
|
||
def build_storybook_error_message(
|
||
*,
|
||
cover_failed: bool,
|
||
failed_pages: list[int],
|
||
) -> str | None:
|
||
"""Summarize storybook image generation errors for the latest attempt."""
|
||
|
||
parts: list[str] = []
|
||
if cover_failed:
|
||
parts.append("封面生成失败")
|
||
if failed_pages:
|
||
pages = "、".join(str(page) for page in sorted(failed_pages))
|
||
parts.append(f"第 {pages} 页插图生成失败")
|
||
return ";".join(parts) if parts else None
|
||
|
||
|
||
def resolve_storybook_image_status(
    *,
    generate_images: bool,
    cover_prompt: str | None,
    cover_url: str | None,
    pages_data: list[dict],
) -> StoryAssetStatus:
    """Resolve the persisted image status for a storybook.

    An asset (the cover or a page) is expected when it has a prompt or a
    URL, and ready when it has a URL. The result is NOT_REQUESTED when image
    generation is off or nothing is expected, READY when every expected
    asset is ready, and FAILED otherwise.
    """
    if not generate_images:
        return StoryAssetStatus.NOT_REQUESTED

    # One flag per expected asset; True means its URL is already present.
    readiness: list[bool] = []

    if cover_prompt or cover_url:
        readiness.append(bool(cover_url))

    for page in pages_data:
        page_url = page.get("image_url")
        if page.get("image_prompt") or page_url:
            readiness.append(bool(page_url))

    if not readiness:
        return StoryAssetStatus.NOT_REQUESTED

    if all(readiness):
        return StoryAssetStatus.READY
    return StoryAssetStatus.FAILED
|
||
|
||
|
||
async def complete_storybook_image_assets(
    story: Story,
    db: AsyncSession,
    *,
    generate_image_func: ImageGenerator,
    job=None,
) -> AssetCompletionResult:
    """Complete missing cover/page images for a persisted storybook.

    The cover is generated when it has a prompt but no URL, and every page
    that has an ``image_prompt`` but no ``image_url`` is generated in order.
    Individual failures are logged and traced without stopping the loop; the
    updated pages, overall image status and error summary are then saved.

    Args:
        story: Storybook whose cover and pages are completed.
        db: Session used for commits, cancellation checks and trace records.
        generate_image_func: Awaitable provider called with a prompt and
            ``db``/``user_id``/``generation_job``/``story_id`` keywords.
        job: Optional generation job forwarded to the cancellation check,
            the trace recorder and the provider.

    Returns:
        An ``AssetCompletionResult`` for ``storybook_images`` carrying the
        resolved image status, the cover URL as ``value`` and the error
        summary (if any) as ``error``.

    Raises:
        HTTPException: 400 when neither the cover nor any page has a prompt.
    """
    pages_data = get_storybook_pages_data(story)
    has_image_prompt = bool(story.cover_prompt) or any(
        page.get("image_prompt") for page in pages_data
    )
    if not has_image_prompt:
        raise HTTPException(status_code=400, detail="Storybook has no image prompts")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="storybook_images_started",
        status="running",
        message="Storybook missing image completion started.",
        metadata={"asset": "image"},
    )

    cover_failed = False
    failed_pages: list[int] = []
    completed_pages: list[int] = []

    if story.cover_prompt and not story.image_url:
        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        # The try block guards only the provider call; the success trace is in
        # the else clause so a trace error cannot flag a generated cover as
        # failed.
        try:
            story.image_url = await generate_image_func(
                story.cover_prompt,
                db=db,
                user_id=story.user_id,
                generation_job=job,
                story_id=story.id,
            )
        except Exception as exc:
            cover_failed = True
            logger.warning(
                "storybook_cover_asset_completion_failed",
                story_id=story.id,
                error=str(exc),
            )
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_cover_image_failed",
                status="failed",
                message="Storybook cover image generation failed.",
                metadata={"asset": "image", "scope": "cover", "error": str(exc)},
            )
        else:
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_cover_image_succeeded",
                status="succeeded",
                message="Storybook cover image was generated.",
                metadata={"asset": "image", "scope": "cover"},
            )

    for page in pages_data:
        # Skip pages that need no image or already have one.
        if not page.get("image_prompt") or page.get("image_url"):
            continue

        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        page_number = page.get("page_number")
        # Same split as for the cover: a page ends up in exactly one of
        # completed_pages / failed_pages, decided by the provider call alone.
        try:
            page["image_url"] = await generate_image_func(
                page["image_prompt"],
                db=db,
                user_id=story.user_id,
                generation_job=job,
                story_id=story.id,
            )
        except Exception as exc:
            if isinstance(page_number, int):
                failed_pages.append(page_number)
            logger.warning(
                "storybook_page_asset_completion_failed",
                story_id=story.id,
                page=page_number,
                error=str(exc),
            )
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_page_image_failed",
                status="failed",
                message="Storybook page image generation failed.",
                metadata={
                    "asset": "image",
                    "scope": "page",
                    "page_number": page_number,
                    "error": str(exc),
                },
            )
        else:
            if isinstance(page_number, int):
                completed_pages.append(page_number)
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_page_image_succeeded",
                status="succeeded",
                message="Storybook page image was generated.",
                metadata={"asset": "image", "scope": "page", "page_number": page_number},
            )

    # Assign the copied page list back so the new image URLs are persisted.
    story.pages = pages_data
    error_message = build_storybook_error_message(
        cover_failed=cover_failed,
        failed_pages=failed_pages,
    )
    image_status = resolve_storybook_image_status(
        generate_images=True,
        cover_prompt=story.cover_prompt,
        cover_url=story.image_url,
        pages_data=pages_data,
    )
    sync_story_status(
        story,
        image_status=image_status,
        last_error=error_message,
    )
    await db.commit()

    result = AssetCompletionResult(
        asset="storybook_images",
        status=image_status,
        value=story.image_url,
        error=error_message,
    )
    # Derive the trace status from the resolved image status rather than the
    # error summary: a failed page without an integer page_number is absent
    # from the summary but still leaves the storybook incomplete.
    completion_status = (
        "succeeded" if image_status == StoryAssetStatus.READY else "failed"
    )
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="storybook_images_completed",
        status=completion_status,
        message="Storybook image completion finished.",
        metadata={
            **asset_result_metadata(result),
            "completed_pages": sorted(completed_pages),
            "failed_pages": sorted(failed_pages),
        },
    )
    return result
|