feat: add generation job cancel and retry queue

This commit is contained in:
2026-04-19 18:45:34 +08:00
parent 6fb128955f
commit b89ca96e4b
18 changed files with 756 additions and 51 deletions

View File

@@ -16,6 +16,26 @@ from app.db.models import GenerationJob, GenerationJobEvent, Story
logger = get_logger(__name__)
def _is_terminal_status(status: str) -> bool:
return status in {"completed", "degraded_completed", "failed", "canceled"}
def _job_supports_queue_control(job: GenerationJob) -> bool:
return job.output_mode in {"story", "storybook"}
def generation_job_can_cancel(job: GenerationJob) -> bool:
return (
_job_supports_queue_control(job)
and job.status == "running"
and job.current_step != "cancel_requested"
)
def generation_job_can_retry(job: GenerationJob) -> bool:
return _job_supports_queue_control(job) and job.status in {"failed", "canceled"}
def _story_snapshot(story: Story | None) -> dict[str, Any]:
if story is None:
return {}
@@ -50,6 +70,13 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"is_terminal": True,
}
if job.status == "canceled":
return {
"progress_percent": 100,
"progress_label": "已取消",
"is_terminal": True,
}
if job.status in {"completed", "degraded_completed"}:
return {
"progress_percent": 100,
@@ -59,7 +86,9 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
progress_map: dict[str, tuple[int, str]] = {
"request_accepted": (5, "已接收请求"),
"retry_queued": (8, "重新排队中"),
"worker_started": (12, "后台任务已开始"),
"cancel_requested": (15, "已请求取消"),
"context_prepared": (20, "上下文已准备"),
"narrative_generated": (45, "正文已生成"),
"story_saved": (60, "主记录已保存"),
@@ -83,6 +112,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"postprocessing_queued": (90, "后处理已排队"),
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_canceled": (100, "任务已取消"),
"generation_completed": (100, "生成完成"),
"generation_stale_failed": (100, "任务超时已收敛"),
}
@@ -106,6 +136,8 @@ def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
def _failure_label(job: GenerationJob) -> str:
if job.status == "canceled":
return "任务已取消"
if job.current_step == "generation_stale_failed":
return "任务超时"
if job.output_mode == "asset_retry":
@@ -196,7 +228,7 @@ async def claim_generation_job_for_worker(
.where(
GenerationJob.id == job_id,
GenerationJob.status == "running",
GenerationJob.current_step == "request_accepted",
GenerationJob.current_step.in_(["request_accepted", "retry_queued"]),
)
.values(current_step="worker_started")
)
@@ -283,6 +315,8 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
"status": job.status,
"current_step": job.current_step,
**progress,
"can_cancel": generation_job_can_cancel(job),
"can_retry": generation_job_can_retry(job),
"result_snapshot": job.result_snapshot or {},
"error_message": job.error_message,
"created_at": job.created_at,
@@ -290,6 +324,88 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
}
async def get_generation_job_for_user(
db: AsyncSession,
*,
job_id: str,
user_id: str,
) -> GenerationJob:
"""Load one generation job owned by the current user."""
result = await db.execute(
select(GenerationJob).where(
GenerationJob.id == job_id,
GenerationJob.user_id == user_id,
)
)
job = result.scalar_one_or_none()
if job is None:
raise HTTPException(status_code=404, detail="Generation job not found")
return job
async def request_generation_job_cancel(
db: AsyncSession,
*,
job_id: str,
user_id: str,
) -> dict[str, Any]:
"""Request cancellation for one queued/running generation job."""
job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
if not _job_supports_queue_control(job):
raise HTTPException(status_code=409, detail="当前任务不支持取消")
if job.status == "canceled":
return generation_job_to_summary(job)
if _is_terminal_status(job.status):
raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")
if job.current_step == "cancel_requested":
return generation_job_to_summary(job)
if job.current_step in {"request_accepted", "retry_queued"}:
story = None
if job.story_id is not None:
story = (
await db.execute(
select(Story).where(
Story.id == job.story_id,
Story.user_id == job.user_id,
)
)
).scalar_one_or_none()
await finish_generation_job(
db,
job=job,
story=story,
status="canceled",
current_step="generation_canceled",
error_message="Generation canceled by user before worker execution started.",
message="Generation job was canceled before worker execution started.",
)
return generation_job_to_summary(job)
previous_step = job.current_step
job.error_message = "Cancellation requested by user."
await record_generation_event(
db,
job=job,
story_id=job.story_id,
event_type="cancel_requested",
status="running",
message="Cancellation requested; worker will stop at the next safe checkpoint.",
metadata={"requested_from_step": previous_step},
commit=False,
)
await db.commit()
await db.refresh(job)
return generation_job_to_summary(job)
async def get_generation_job_detail(
db: AsyncSession,
*,

View File

@@ -37,6 +37,9 @@ from app.services.generation_jobs import (
create_generation_job,
ensure_no_active_story_generation_job,
finish_generation_job,
generation_job_can_retry,
generation_job_to_summary,
get_generation_job_for_user,
record_generation_event,
)
from app.services.memory_service import build_enhanced_memory_context
@@ -73,6 +76,10 @@ class AssetCompletionResult:
return self.status == StoryAssetStatus.READY and self.error is None
class GenerationJobCanceledError(Exception):
"""Raised when a running worker job has been canceled by the user."""
async def _record_job_event_if_present(
db: AsyncSession,
*,
@@ -99,6 +106,33 @@ async def _record_job_event_if_present(
)
async def _stop_if_job_cancel_requested(
db: AsyncSession,
*,
job,
story: Story | None = None,
) -> None:
"""Stop a worker-owned job at the next safe checkpoint after cancellation."""
if job is None:
return
await db.refresh(job)
if job.current_step != "cancel_requested":
return
await finish_generation_job(
db,
job=job,
story=story,
status="canceled",
current_step="generation_canceled",
error_message="Generation canceled by user.",
message="Generation job was canceled after a user request.",
)
raise GenerationJobCanceledError()
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
"""Build JSON-safe metadata for asset workflow events."""
@@ -192,6 +226,7 @@ async def _prepare_generation_context(
"has_memory_context": bool(memory_context),
},
)
await _stop_if_job_cancel_requested(db, job=job)
return resolved_profile_id, resolved_universe_id, memory_context
@@ -318,6 +353,7 @@ async def _generate_storybook_image_assets(
]
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
await _stop_if_job_cancel_requested(db, job=job)
await _record_job_event_if_present(
db,
job=job,
@@ -334,6 +370,7 @@ async def _generate_storybook_image_assets(
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
await _stop_if_job_cancel_requested(db, job=job)
try:
return await generate_image(
storybook.cover_prompt,
@@ -350,6 +387,7 @@ async def _generate_storybook_image_assets(
if not page.image_prompt or page.image_url:
return
await _stop_if_job_cancel_requested(db, job=job)
try:
page.image_url = await generate_image(
page.image_prompt,
@@ -506,6 +544,7 @@ async def _complete_cover_image_asset(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -517,6 +556,7 @@ async def _complete_cover_image_asset(
)
try:
await _stop_if_job_cancel_requested(db, job=job, story=story)
image_url = await generate_image(
story.cover_prompt,
db=db,
@@ -605,6 +645,7 @@ async def _complete_storybook_image_assets(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -620,6 +661,7 @@ async def _complete_storybook_image_assets(
completed_pages: list[int] = []
if story.cover_prompt and not story.image_url:
await _stop_if_job_cancel_requested(db, job=job, story=story)
try:
story.image_url = await generate_image(
story.cover_prompt,
@@ -658,6 +700,7 @@ async def _complete_storybook_image_assets(
if not page.get("image_prompt") or page.get("image_url"):
continue
await _stop_if_job_cancel_requested(db, job=job, story=story)
try:
page["image_url"] = await generate_image(
page["image_prompt"],
@@ -800,6 +843,7 @@ async def _complete_audio_asset(
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
await _stop_if_job_cancel_requested(db, job=job, story=story)
await _record_job_event_if_present(
db,
job=job,
@@ -811,6 +855,7 @@ async def _complete_audio_asset(
)
try:
await _stop_if_job_cancel_requested(db, job=job, story=story)
audio_data = await text_to_speech(
story.story_text,
db=db,
@@ -933,6 +978,7 @@ async def generate_and_save_story(
)
try:
await _stop_if_job_cancel_requested(db, job=job)
result = await generate_story_content(
input_type=request.type,
data=request.data,
@@ -955,8 +1001,9 @@ async def generate_and_save_story(
message="Story narrative was generated.",
metadata={"mode": result.mode, "title": result.title},
)
await _stop_if_job_cancel_requested(db, job=job)
return await _persist_text_story_result(
story = await _persist_text_story_result(
result=result,
user_id=user_id,
profile_id=profile_id,
@@ -964,6 +1011,8 @@ async def generate_and_save_story(
db=db,
job=job,
)
await _stop_if_job_cancel_requested(db, job=job, story=story)
return story
async def generate_full_story_service(
@@ -975,6 +1024,7 @@ async def generate_full_story_service(
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
story = await generate_and_save_story(request, user_id, db, job=job)
await _stop_if_job_cancel_requested(db, job=job, story=story)
image_url: str | None = None
errors: dict[str, str | None] = {}
@@ -1036,6 +1086,7 @@ async def generate_storybook_service(
)
try:
await _stop_if_job_cancel_requested(db, job=job)
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
@@ -1060,12 +1111,14 @@ async def generate_storybook_service(
"page_count": len(storybook.pages),
},
)
await _stop_if_job_cancel_requested(db, job=job)
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
if request.generate_images:
await _stop_if_job_cancel_requested(db, job=job)
(
final_cover_url,
cover_failed,
@@ -1089,6 +1142,7 @@ async def generate_storybook_service(
db=db,
job=job,
)
await _stop_if_job_cancel_requested(db, job=job, story=story)
response_pages = _storybook_pages_to_response(pages_data)
@@ -1124,6 +1178,18 @@ async def generate_generation_service(
request_payload=request.model_dump(mode="json"),
)
await _dispatch_generation_job(db, job=job)
return _build_queued_generation_response(request, job_id=job.id)
async def _dispatch_generation_job(
db: AsyncSession,
*,
job: GenerationJob,
) -> None:
"""Dispatch one accepted generation job to the background worker."""
try:
from app.tasks.generation_workflow import run_generation_workflow_task
@@ -1144,8 +1210,6 @@ async def generate_generation_service(
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
) from exc
return _build_queued_generation_response(request, job_id=job.id)
def _build_queued_generation_response(
request: GenerationRequest,
@@ -1184,6 +1248,8 @@ async def execute_generation_job_service(
db,
job=job,
)
except GenerationJobCanceledError:
return _build_canceled_generation_response(job)
except HTTPException as exc:
await finish_generation_job(
db,
@@ -1210,6 +1276,24 @@ async def execute_generation_job_service(
return response
def _build_canceled_generation_response(job: GenerationJob) -> GenerationResponse:
"""Build a compact response for a worker job that ended as canceled."""
snapshot = job.result_snapshot or {}
return GenerationResponse(
id=snapshot.get("story_id"),
generation_job_id=job.id,
title="生成任务已取消",
mode="storybook" if job.output_mode == "storybook" else "generated",
generation_status=str(snapshot.get("generation_status") or "failed"),
text_status=str(snapshot.get("text_status") or "failed"),
image_status=str(snapshot.get("image_status") or "not_requested"),
audio_status=str(snapshot.get("audio_status") or "not_requested"),
last_error=str(snapshot.get("last_error") or "Generation canceled by user."),
retryable_assets=list(snapshot.get("retryable_assets") or []),
)
async def run_generation_job_service(
job_id: str,
db: AsyncSession,
@@ -1225,6 +1309,46 @@ async def run_generation_job_service(
return job
async def retry_generation_job_service(
job_id: str,
user_id: str,
db: AsyncSession,
) -> dict:
"""Clone one failed/canceled generation job and queue it again."""
source_job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
if not generation_job_can_retry(source_job):
raise HTTPException(status_code=409, detail="当前任务还不能重新排队")
if source_job.story_id is not None:
await ensure_no_active_story_generation_job(
db,
story_id=source_job.story_id,
user_id=user_id,
)
retry_job = await create_generation_job(
db,
user_id=user_id,
output_mode=source_job.output_mode,
input_type=source_job.input_type,
request_payload=source_job.request_payload or {},
story_id=source_job.story_id,
)
await record_generation_event(
db,
job=retry_job,
story_id=retry_job.story_id,
event_type="retry_queued",
status="queued",
message="Retry job accepted from a previous terminal generation.",
metadata={"source_job_id": source_job.id},
)
await _dispatch_generation_job(db, job=retry_job)
await db.refresh(retry_job)
return generation_job_to_summary(retry_job)
async def _generate_generation_service_with_job(
request: GenerationRequest,
user_id: str,