feat: add generation job cancel and retry queue

This commit is contained in:
2026-04-19 18:45:34 +08:00
parent 6fb128955f
commit b89ca96e4b
18 changed files with 756 additions and 51 deletions

View File

@@ -37,6 +37,9 @@ from app.services.generation_jobs import (
create_generation_job,
ensure_no_active_story_generation_job,
finish_generation_job,
generation_job_can_retry,
generation_job_to_summary,
get_generation_job_for_user,
record_generation_event,
)
from app.services.memory_service import build_enhanced_memory_context
@@ -73,6 +76,10 @@ class AssetCompletionResult:
return self.status == StoryAssetStatus.READY and self.error is None
class GenerationJobCanceledError(Exception):
    """Signal that the user canceled a worker job while it was still running.

    Carries no payload of its own; it only marks the cancellation so callers
    can tell it apart from other failures.
    """
async def _record_job_event_if_present(
db: AsyncSession,
*,
@@ -99,6 +106,33 @@ async def _record_job_event_if_present(
)
async def _stop_if_job_cancel_requested(
    db: AsyncSession,
    *,
    job,
    story: Story | None = None,
) -> None:
    """Abort the workflow at a checkpoint if the job was asked to cancel.

    Does nothing when ``job`` is ``None``. Otherwise the job is refreshed
    through ``db``; if its ``current_step`` equals ``"cancel_requested"`` the
    job is finished with status ``"canceled"`` and the workflow is unwound.

    Args:
        db: Session used to refresh the job and to record its final state.
        job: The generation job being executed, or ``None`` when the caller
            is not running under a job.
        story: Optional story forwarded to ``finish_generation_job``.

    Raises:
        GenerationJobCanceledError: After the job has been marked canceled.
    """
    if job is not None:
        # Refresh before inspecting current_step so the check does not rely
        # on whatever state this object was last loaded with.
        await db.refresh(job)
        cancel_requested = job.current_step == "cancel_requested"
        if cancel_requested:
            await finish_generation_job(
                db,
                job=job,
                story=story,
                status="canceled",
                current_step="generation_canceled",
                error_message="Generation canceled by user.",
                message="Generation job was canceled after a user request.",
            )
            raise GenerationJobCanceledError()
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
"""Build JSON-safe metadata for asset workflow events."""
@@ -192,6 +226,7 @@ async def _prepare_generation_context(
"has_memory_context": bool(memory_context),
},
)
await _stop_if_job_cancel_requested(db, job=job)
return resolved_profile_id, resolved_universe_id, memory_context
@@ -318,6 +353,7 @@ async def _generate_storybook_image_assets(
]
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
await _stop_if_job_cancel_requested(db, job=job)
await _record_job_event_if_present(
db,
job=job,
@@ -334,6 +370,7 @@ async def _generate_storybook_image_assets(
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
await _stop_if_job_cancel_requested(db, job=job)
try:
return await generate_image(
storybook.cover_prompt,
@@ -350,6 +387,7 @@ async def _generate_storybook_image_assets(
if not page.image_prompt or page.image_url:
return
await _stop_if_job_cancel_requested(db, job=job)
try:
page.image_url = await generate_image(
page.image_prompt,
@@ -506,6 +544,7 @@ async def _complete_cover_image_asset(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -517,6 +556,7 @@ async def _complete_cover_image_asset(
)
try:
await _stop_if_job_cancel_requested(db, job=job, story=story)
image_url = await generate_image(
story.cover_prompt,
db=db,
@@ -605,6 +645,7 @@ async def _complete_storybook_image_assets(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -620,6 +661,7 @@ async def _complete_storybook_image_assets(
completed_pages: list[int] = []
if story.cover_prompt and not story.image_url:
await _stop_if_job_cancel_requested(db, job=job, story=story)
try:
story.image_url = await generate_image(
story.cover_prompt,
@@ -658,6 +700,7 @@ async def _complete_storybook_image_assets(
if not page.get("image_prompt") or page.get("image_url"):
continue
await _stop_if_job_cancel_requested(db, job=job, story=story)
try:
page["image_url"] = await generate_image(
page["image_prompt"],
@@ -800,6 +843,7 @@ async def _complete_audio_asset(
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -811,6 +855,7 @@ async def _complete_audio_asset(
)
try:
await _stop_if_job_cancel_requested(db, job=job, story=story)
audio_data = await text_to_speech(
story.story_text,
db=db,
@@ -933,6 +978,7 @@ async def generate_and_save_story(
)
try:
await _stop_if_job_cancel_requested(db, job=job)
result = await generate_story_content(
input_type=request.type,
data=request.data,
@@ -955,8 +1001,9 @@ async def generate_and_save_story(
message="Story narrative was generated.",
metadata={"mode": result.mode, "title": result.title},
)
await _stop_if_job_cancel_requested(db, job=job)
return await _persist_text_story_result(
story = await _persist_text_story_result(
result=result,
user_id=user_id,
profile_id=profile_id,
@@ -964,6 +1011,8 @@ async def generate_and_save_story(
db=db,
job=job,
)
await _stop_if_job_cancel_requested(db, job=job, story=story)
return story
async def generate_full_story_service(
@@ -975,6 +1024,7 @@ async def generate_full_story_service(
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
story = await generate_and_save_story(request, user_id, db, job=job)
await _stop_if_job_cancel_requested(db, job=job, story=story)
image_url: str | None = None
errors: dict[str, str | None] = {}
@@ -1036,6 +1086,7 @@ async def generate_storybook_service(
)
try:
await _stop_if_job_cancel_requested(db, job=job)
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
@@ -1060,12 +1111,14 @@ async def generate_storybook_service(
"page_count": len(storybook.pages),
},
)
await _stop_if_job_cancel_requested(db, job=job)
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
if request.generate_images:
await _stop_if_job_cancel_requested(db, job=job)
(
final_cover_url,
cover_failed,
@@ -1089,6 +1142,7 @@ async def generate_storybook_service(
db=db,
job=job,
)
await _stop_if_job_cancel_requested(db, job=job, story=story)
response_pages = _storybook_pages_to_response(pages_data)
@@ -1124,6 +1178,18 @@ async def generate_generation_service(
request_payload=request.model_dump(mode="json"),
)
await _dispatch_generation_job(db, job=job)
return _build_queued_generation_response(request, job_id=job.id)
async def _dispatch_generation_job(
db: AsyncSession,
*,
job: GenerationJob,
) -> None:
"""Dispatch one accepted generation job to the background worker."""
try:
from app.tasks.generation_workflow import run_generation_workflow_task
@@ -1144,8 +1210,6 @@ async def generate_generation_service(
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
) from exc
return _build_queued_generation_response(request, job_id=job.id)
def _build_queued_generation_response(
request: GenerationRequest,
@@ -1184,6 +1248,8 @@ async def execute_generation_job_service(
db,
job=job,
)
except GenerationJobCanceledError:
return _build_canceled_generation_response(job)
except HTTPException as exc:
await finish_generation_job(
db,
@@ -1210,6 +1276,24 @@ async def execute_generation_job_service(
return response
def _build_canceled_generation_response(job: GenerationJob) -> GenerationResponse:
    """Assemble the response returned for a job that ended as canceled.

    Values come from ``job.result_snapshot`` when present; missing or falsy
    entries fall back to fixed defaults so the response is always populated.

    Args:
        job: The canceled generation job.

    Returns:
        A ``GenerationResponse`` describing the canceled job.
    """
    snapshot = job.result_snapshot or {}

    def _text(key: str, fallback: str) -> str:
        # Falsy snapshot values (None, "") are replaced by the fallback.
        return str(snapshot.get(key) or fallback)

    if job.output_mode == "storybook":
        response_mode = "storybook"
    else:
        response_mode = "generated"

    return GenerationResponse(
        id=snapshot.get("story_id"),
        generation_job_id=job.id,
        title="生成任务已取消",
        mode=response_mode,
        generation_status=_text("generation_status", "failed"),
        text_status=_text("text_status", "failed"),
        image_status=_text("image_status", "not_requested"),
        audio_status=_text("audio_status", "not_requested"),
        last_error=_text("last_error", "Generation canceled by user."),
        retryable_assets=list(snapshot.get("retryable_assets") or []),
    )
async def run_generation_job_service(
job_id: str,
db: AsyncSession,
@@ -1225,6 +1309,46 @@ async def run_generation_job_service(
return job
async def retry_generation_job_service(
    job_id: str,
    user_id: str,
    db: AsyncSession,
) -> dict:
    """Queue a fresh copy of a retry-eligible generation job.

    The source job is looked up for ``user_id``; a new job is created with the
    same output mode, input type, request payload and story id, a
    ``retry_queued`` event is recorded on it, and it is dispatched.

    Args:
        job_id: Identifier of the job to retry.
        user_id: Owner used both for the lookup and for the new job.
        db: Session used for every read and write.

    Returns:
        The summary produced by ``generation_job_to_summary`` for the new job.

    Raises:
        HTTPException: 409 when ``generation_job_can_retry`` rejects the job.
    """
    previous_job = await get_generation_job_for_user(
        db,
        job_id=job_id,
        user_id=user_id,
    )
    retry_allowed = generation_job_can_retry(previous_job)
    if not retry_allowed:
        raise HTTPException(status_code=409, detail="当前任务还不能重新排队")

    # A story-bound retry must not run alongside another active job for the
    # same story.
    if previous_job.story_id is not None:
        await ensure_no_active_story_generation_job(
            db,
            story_id=previous_job.story_id,
            user_id=user_id,
        )

    clone_fields = {
        "user_id": user_id,
        "output_mode": previous_job.output_mode,
        "input_type": previous_job.input_type,
        "request_payload": previous_job.request_payload or {},
        "story_id": previous_job.story_id,
    }
    queued_job = await create_generation_job(db, **clone_fields)

    await record_generation_event(
        db,
        job=queued_job,
        story_id=queued_job.story_id,
        event_type="retry_queued",
        status="queued",
        message="Retry job accepted from a previous terminal generation.",
        metadata={"source_job_id": previous_job.id},
    )

    await _dispatch_generation_job(db, job=queued_job)
    # Refresh before summarizing so the summary is built from reloaded state.
    await db.refresh(queued_job)
    summary = generation_job_to_summary(queued_job)
    return summary
async def _generate_generation_service_with_job(
request: GenerationRequest,
user_id: str,