Files
dreamweaver/backend/app/services/harness/asset_workflows.py
2026-06-21 22:31:38 +08:00

469 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Artifact completion workflows for the generation harness runtime."""
from collections.abc import Awaitable, Callable
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models import Story
from app.services.harness.artifacts import AssetCompletionResult, asset_result_metadata
from app.services.harness.control import ExecutionControl
from app.services.harness.trace import TraceRecorder
from app.services.story_status import StoryAssetStatus, sync_story_status
# Module-level structured logger; calls in this file pass an event name plus keyword fields.
logger = get_logger(__name__)

# Async image generator. Called with a prompt plus db/user_id/generation_job/story_id
# keyword arguments; the awaited string is stored as an image URL.
ImageGenerator = Callable[..., Awaitable[str]]
# Async text-to-speech generator. Called with the story text plus the same keyword
# arguments; the awaited bytes are written to the audio cache.
TTSGenerator = Callable[..., Awaitable[bytes]]
# Synchronous check: does a cached audio entry exist for the given ``story.audio_path``?
AudioCacheExists = Callable[[str], bool]
# Synchronous read of the cached audio bytes for the given ``story.audio_path``.
AudioCacheReader = Callable[[str], bytes]
# Synchronous write of audio bytes for a story id; the returned string is stored
# as ``story.audio_path``.
AudioCacheWriter = Callable[[int, bytes], str]
async def complete_cover_image_asset(
    story: Story,
    db: AsyncSession,
    *,
    generate_image_func: ImageGenerator,
    raise_on_failure: bool = False,
    last_error_prefix: str | None = None,
    log_event: str = "cover_asset_generation_failed",
    job=None,
) -> AssetCompletionResult:
    """Generate or retry a text story cover through one asset workflow.

    The story's image status is set to GENERATING and committed, then
    ``generate_image_func`` is awaited with the cover prompt. On success the
    returned URL is stored on the story and the status becomes READY; on any
    exception the status becomes FAILED and the error text is stored as
    ``last_error``. Trace steps are recorded for start, success and failure.

    Args:
        story: Story whose ``cover_prompt`` is sent to the image generator.
        db: Session used to commit status changes and to record trace steps.
        generate_image_func: Awaited with the prompt plus ``db``, ``user_id``,
            ``generation_job`` and ``story_id`` keywords; its result is stored
            as ``story.image_url``.
        raise_on_failure: When true, a generation failure is raised as an
            HTTP 500 instead of being returned. The flag is also copied into
            ``blocks_main_result`` of the result.
        last_error_prefix: Optional prefix joined to the provider error with
            ``": "`` before it is stored as ``last_error``.
        log_event: Event name used for the warning logged on failure.
        job: Optional generation job forwarded to cancellation checks, trace
            steps and the image generator.

    Returns:
        A READY result carrying the image URL, or a FAILED result carrying the
        provider error when ``raise_on_failure`` is false.

    Raises:
        HTTPException: 400 when the story has no cover prompt; 500 when
            generation fails and ``raise_on_failure`` is true.
    """
    if not story.cover_prompt:
        raise HTTPException(status_code=400, detail="Story has no cover prompt")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="cover_image_started",
        status="running",
        message="Cover image generation started.",
        metadata={"asset": "image", "cover_prompt_present": True},
    )

    # Last cancellation check before the provider call. It is deliberately kept
    # outside the ``try`` so that a stop signalled here is not caught by the
    # broad handler below and persisted as an image failure; the storybook
    # workflow in this module places its checks the same way.
    # NOTE(review): assumes stop_if_cancel_requested signals a stop by raising —
    # confirm against ExecutionControl.
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    try:
        image_url = await generate_image_func(
            story.cover_prompt,
            db=db,
            user_id=story.user_id,
            generation_job=job,
            story_id=story.id,
        )
        story.image_url = image_url
        sync_story_status(story, image_status=StoryAssetStatus.READY)
        await db.commit()
        result = AssetCompletionResult(
            asset="cover_image",
            status=StoryAssetStatus.READY,
            value=image_url,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="cover_image_succeeded",
            status="succeeded",
            message="Cover image was generated.",
            metadata=asset_result_metadata(result),
        )
        return result
    except Exception as exc:
        provider_error = str(exc)
        # The persisted error may carry a caller-supplied prefix; the log,
        # result and HTTP detail always use the raw provider error.
        last_error = (
            f"{last_error_prefix}: {provider_error}"
            if last_error_prefix
            else provider_error
        )
        sync_story_status(
            story,
            image_status=StoryAssetStatus.FAILED,
            last_error=last_error,
        )
        await db.commit()
        logger.warning(log_event, story_id=story.id, error=provider_error)
        result = AssetCompletionResult(
            asset="cover_image",
            status=StoryAssetStatus.FAILED,
            error=provider_error,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="cover_image_failed",
            status="failed",
            message="Cover image generation failed.",
            metadata=asset_result_metadata(result),
        )
        if raise_on_failure:
            raise HTTPException(
                status_code=500,
                detail=f"Image generation failed: {provider_error}",
            ) from exc
        return result
async def read_cached_audio_asset(
    story: Story,
    db: AsyncSession,
    *,
    audio_cache_exists_func: AudioCacheExists,
    read_audio_cache_func: AudioCacheReader,
) -> bytes | None:
    """Read cached audio or repair stale audio cache metadata.

    Args:
        story: Story whose ``audio_path`` and ``audio_status`` are inspected
            and, when stale, corrected.
        db: Session used to commit any metadata correction.
        audio_cache_exists_func: Reports whether the cache entry for
            ``story.audio_path`` exists.
        read_audio_cache_func: Returns the cached bytes for
            ``story.audio_path``.

    Returns:
        The cached audio bytes, or ``None`` when the story has no audio path
        or the cache entry it points to is missing.
    """
    if not story.audio_path:
        return None

    # Probe the cache exactly once: a second probe is redundant work and, if
    # the two answers ever disagreed, neither branch would repair the metadata.
    if audio_cache_exists_func(story.audio_path):
        # The entry exists, so make sure the persisted status agrees.
        if story.audio_status != StoryAssetStatus.READY.value:
            sync_story_status(story, audio_status=StoryAssetStatus.READY)
            await db.commit()
        return read_audio_cache_func(story.audio_path)

    # The story points at a cache entry that no longer exists: drop the path
    # and, if the status still claims READY, reset it.
    logger.warning(
        "story_audio_cache_missing",
        story_id=story.id,
        audio_path=story.audio_path,
    )
    story.audio_path = None
    if story.audio_status == StoryAssetStatus.READY.value:
        sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
    await db.commit()
    return None
async def complete_audio_asset(
    story: Story,
    db: AsyncSession,
    *,
    text_to_speech_func: TTSGenerator,
    audio_cache_exists_func: AudioCacheExists,
    read_audio_cache_func: AudioCacheReader,
    write_story_audio_cache_func: AudioCacheWriter,
    raise_on_failure: bool = True,
    job=None,
) -> AssetCompletionResult:
    """Complete TTS audio generation through one asset workflow.

    Cached audio is reused when ``read_cached_audio_asset`` returns it.
    Otherwise the audio status is set to GENERATING and committed,
    ``text_to_speech_func`` is awaited with the story text, and the resulting
    bytes are written through ``write_story_audio_cache_func``. On any
    exception the audio path is cleared, the status becomes FAILED and the
    error text is stored as ``last_error``. Trace steps are recorded for cache
    hits, start, success and failure.

    Args:
        story: Story whose ``story_text`` is synthesized.
        db: Session used to commit status changes and to record trace steps.
        text_to_speech_func: Awaited with the story text plus ``db``,
            ``user_id``, ``generation_job`` and ``story_id`` keywords.
        audio_cache_exists_func: Forwarded to ``read_cached_audio_asset``.
        read_audio_cache_func: Forwarded to ``read_cached_audio_asset``.
        write_story_audio_cache_func: Called with the story id and audio
            bytes; its result is stored as ``story.audio_path``.
        raise_on_failure: When true, a generation failure is raised as an
            HTTP 500 instead of being returned. The flag is also copied into
            ``blocks_main_result`` of the result.
        job: Optional generation job forwarded to cancellation checks, trace
            steps and the TTS generator.

    Returns:
        A READY result carrying the audio bytes (cached or freshly generated),
        or a FAILED result carrying the provider error when
        ``raise_on_failure`` is false.

    Raises:
        HTTPException: 400 when the story has no text; 500 when generation
            fails and ``raise_on_failure`` is true.
    """
    if not story.story_text:
        raise HTTPException(status_code=400, detail="Story has no text")

    cached_audio = await read_cached_audio_asset(
        story,
        db,
        audio_cache_exists_func=audio_cache_exists_func,
        read_audio_cache_func=read_audio_cache_func,
    )
    if cached_audio is not None:
        result = AssetCompletionResult(
            asset="audio",
            status=StoryAssetStatus.READY,
            value=cached_audio,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="audio_cache_hit",
            status="succeeded",
            message="Cached story audio was reused.",
            metadata=asset_result_metadata(result),
        )
        return result

    sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="audio_started",
        status="running",
        message="Story audio generation started.",
        metadata={"asset": "audio"},
    )

    # Last cancellation check before the provider call. It is deliberately kept
    # outside the ``try`` so that a stop signalled here is not caught by the
    # broad handler below and persisted as an audio failure; the storybook
    # workflow in this module places its checks the same way.
    # NOTE(review): assumes stop_if_cancel_requested signals a stop by raising —
    # confirm against ExecutionControl.
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    try:
        audio_data = await text_to_speech_func(
            story.story_text,
            db=db,
            user_id=story.user_id,
            generation_job=job,
            story_id=story.id,
        )
        story.audio_path = write_story_audio_cache_func(story.id, audio_data)
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.READY,
        )
        await db.commit()
        result = AssetCompletionResult(
            asset="audio",
            status=StoryAssetStatus.READY,
            value=audio_data,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="audio_succeeded",
            status="succeeded",
            message="Story audio was generated and cached.",
            metadata=asset_result_metadata(result),
        )
        return result
    except Exception as exc:
        provider_error = str(exc)
        # Clear the path so the story never references audio from a failed attempt.
        story.audio_path = None
        sync_story_status(
            story,
            audio_status=StoryAssetStatus.FAILED,
            last_error=provider_error,
        )
        await db.commit()
        logger.error("audio_generation_failed", story_id=story.id, error=provider_error)
        result = AssetCompletionResult(
            asset="audio",
            status=StoryAssetStatus.FAILED,
            error=provider_error,
            blocks_main_result=raise_on_failure,
        )
        await TraceRecorder(db).record_step(
            job=job,
            story_id=story.id,
            event_type="audio_failed",
            status="failed",
            message="Story audio generation failed.",
            metadata=asset_result_metadata(result),
        )
        if raise_on_failure:
            raise HTTPException(
                status_code=500,
                detail=f"Audio generation failed: {provider_error}",
            ) from exc
        return result
def get_storybook_pages_data(story: Story) -> list[dict]:
    """Copy the story's persisted pages into a list of independent dicts.

    Entries of ``story.pages`` that are not dicts are dropped, and a missing
    or empty ``pages`` value yields an empty list. Each kept entry is
    shallow-copied, so changing a key in the returned dicts does not change
    the dict objects held by ``story.pages``.
    """
    page_copies: list[dict] = []
    for entry in story.pages or []:
        if isinstance(entry, dict):
            page_copies.append(dict(entry))
    return page_copies
def build_storybook_error_message(
*,
cover_failed: bool,
failed_pages: list[int],
) -> str | None:
"""Summarize storybook image generation errors for the latest attempt."""
parts: list[str] = []
if cover_failed:
parts.append("封面生成失败")
if failed_pages:
pages = "".join(str(page) for page in sorted(failed_pages))
parts.append(f"{pages} 页插图生成失败")
return "".join(parts) if parts else None
def resolve_storybook_image_status(
    *,
    generate_images: bool,
    cover_prompt: str | None,
    cover_url: str | None,
    pages_data: list[dict],
) -> StoryAssetStatus:
    """Derive the image status to persist for a storybook.

    An image asset is expected when it has either a prompt or a URL, and it
    counts as ready once it has a URL. The result is NOT_REQUESTED when image
    generation is disabled or no asset is expected, READY when every expected
    asset has a URL, and FAILED otherwise.
    """
    if not generate_images:
        return StoryAssetStatus.NOT_REQUESTED

    # One flag per expected asset: True when its URL is already present.
    readiness: list[bool] = []
    if cover_prompt or cover_url:
        readiness.append(bool(cover_url))
    readiness.extend(
        bool(page.get("image_url"))
        for page in pages_data
        if page.get("image_prompt") or page.get("image_url")
    )

    if not readiness:
        return StoryAssetStatus.NOT_REQUESTED
    if all(readiness):
        return StoryAssetStatus.READY
    return StoryAssetStatus.FAILED
async def complete_storybook_image_assets(
    story: Story,
    db: AsyncSession,
    *,
    generate_image_func: ImageGenerator,
    job=None,
) -> AssetCompletionResult:
    """Fill in the missing cover and page images of a persisted storybook.

    Only assets that have a prompt but no URL are generated. A failure of one
    asset is logged and traced and does not stop the remaining assets. After
    all attempts the pages, the resolved image status and the error summary
    are stored on the story and committed.

    Args:
        story: Storybook whose cover and pages are completed.
        db: Session used to commit changes and to record trace steps.
        generate_image_func: Awaited with a prompt plus ``db``, ``user_id``,
            ``generation_job`` and ``story_id`` keywords; its result is stored
            as the image URL of the cover or page.
        job: Optional generation job forwarded to cancellation checks, trace
            steps and the image generator.

    Returns:
        A result whose status is the resolved image status, whose value is the
        cover URL and whose error is the failure summary (``None`` when no
        failure was recorded).

    Raises:
        HTTPException: 400 when neither the cover nor any page has a prompt.
    """
    pages = get_storybook_pages_data(story)
    if not story.cover_prompt and not any(p.get("image_prompt") for p in pages):
        raise HTTPException(status_code=400, detail="Storybook has no image prompts")

    sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
    await db.commit()
    await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="storybook_images_started",
        status="running",
        message="Storybook missing image completion started.",
        metadata={"asset": "image"},
    )

    cover_failed = False
    pages_failed: list[int] = []
    pages_done: list[int] = []

    # Cover: attempted only when a prompt exists and no URL has been stored yet.
    if story.cover_prompt and not story.image_url:
        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        try:
            cover_url = await generate_image_func(
                story.cover_prompt,
                db=db,
                user_id=story.user_id,
                generation_job=job,
                story_id=story.id,
            )
            story.image_url = cover_url
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_cover_image_succeeded",
                status="succeeded",
                message="Storybook cover image was generated.",
                metadata={"asset": "image", "scope": "cover"},
            )
        except Exception as exc:
            cover_failed = True
            reason = str(exc)
            logger.warning(
                "storybook_cover_asset_completion_failed",
                story_id=story.id,
                error=reason,
            )
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_cover_image_failed",
                status="failed",
                message="Storybook cover image generation failed.",
                metadata={"asset": "image", "scope": "cover", "error": reason},
            )

    # Pages: same rule as the cover, applied to each page independently.
    for page in pages:
        prompt = page.get("image_prompt")
        if not prompt or page.get("image_url"):
            continue

        number = page.get("page_number")
        # Only integer page numbers are tracked in the completed/failed lists.
        number_is_int = isinstance(number, int)

        await ExecutionControl(db).stop_if_cancel_requested(job=job, story=story)
        try:
            page["image_url"] = await generate_image_func(
                prompt,
                db=db,
                user_id=story.user_id,
                generation_job=job,
                story_id=story.id,
            )
            if number_is_int:
                pages_done.append(number)
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_page_image_succeeded",
                status="succeeded",
                message="Storybook page image was generated.",
                metadata={"asset": "image", "scope": "page", "page_number": number},
            )
        except Exception as exc:
            reason = str(exc)
            if number_is_int:
                pages_failed.append(number)
            logger.warning(
                "storybook_page_asset_completion_failed",
                story_id=story.id,
                page=number,
                error=reason,
            )
            failure_details = {
                "asset": "image",
                "scope": "page",
                "page_number": number,
                "error": reason,
            }
            await TraceRecorder(db).record_step(
                job=job,
                story_id=story.id,
                event_type="storybook_page_image_failed",
                status="failed",
                message="Storybook page image generation failed.",
                metadata=failure_details,
            )

    # Persist the updated pages together with the overall status and summary.
    story.pages = pages
    error_message = build_storybook_error_message(
        cover_failed=cover_failed,
        failed_pages=pages_failed,
    )
    image_status = resolve_storybook_image_status(
        generate_images=True,
        cover_prompt=story.cover_prompt,
        cover_url=story.image_url,
        pages_data=pages,
    )
    sync_story_status(
        story,
        image_status=image_status,
        last_error=error_message,
    )
    await db.commit()

    result = AssetCompletionResult(
        asset="storybook_images",
        status=image_status,
        value=story.image_url,
        error=error_message,
    )

    if error_message:
        final_status = "failed"
    else:
        final_status = "succeeded"
    summary = dict(asset_result_metadata(result))
    summary["completed_pages"] = sorted(pages_done)
    summary["failed_pages"] = sorted(pages_failed)
    await TraceRecorder(db).record_step(
        job=job,
        story_id=story.id,
        event_type="storybook_images_completed",
        status=final_status,
        message="Storybook image completion finished.",
        metadata=summary,
    )
    return result