feat: add generation job cancel and retry queue

This commit is contained in:
2026-04-19 18:45:34 +08:00
parent 6fb128955f
commit b89ca96e4b
18 changed files with 756 additions and 51 deletions

View File

@@ -16,6 +16,26 @@ from app.db.models import GenerationJob, GenerationJobEvent, Story
logger = get_logger(__name__)
def _is_terminal_status(status: str) -> bool:
    """Return True when ``status`` is one of the terminal job statuses.

    The terminal statuses are ``completed``, ``degraded_completed``,
    ``failed`` and ``canceled``; every other value yields False.
    """
    terminal_statuses = frozenset(
        {"completed", "degraded_completed", "failed", "canceled"}
    )
    return status in terminal_statuses
def _job_supports_queue_control(job: GenerationJob) -> bool:
    """Return True when the job's ``output_mode`` is ``story`` or ``storybook``.

    Only these two output modes are eligible for the cancel/retry checks
    that build on this helper.
    """
    controllable_modes = frozenset({"story", "storybook"})
    return job.output_mode in controllable_modes
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True when a cancel request is currently allowed for ``job``.

    All three conditions must hold: the job's output mode supports queue
    control, its status is ``running``, and its current step is not already
    ``cancel_requested``.
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    # A job whose cancellation is already pending cannot be canceled again.
    return job.current_step != "cancel_requested"
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True when ``job`` may be retried.

    The job's output mode must support queue control and its status must be
    either ``failed`` or ``canceled``.
    """
    if not _job_supports_queue_control(job):
        return False
    retryable_statuses = {"failed", "canceled"}
    return job.status in retryable_statuses
def _story_snapshot(story: Story | None) -> dict[str, Any]:
if story is None:
return {}
@@ -50,6 +70,13 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"is_terminal": True,
}
if job.status == "canceled":
return {
"progress_percent": 100,
"progress_label": "已取消",
"is_terminal": True,
}
if job.status in {"completed", "degraded_completed"}:
return {
"progress_percent": 100,
@@ -59,7 +86,9 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
progress_map: dict[str, tuple[int, str]] = {
"request_accepted": (5, "已接收请求"),
"retry_queued": (8, "重新排队中"),
"worker_started": (12, "后台任务已开始"),
"cancel_requested": (15, "已请求取消"),
"context_prepared": (20, "上下文已准备"),
"narrative_generated": (45, "正文已生成"),
"story_saved": (60, "主记录已保存"),
@@ -83,6 +112,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"postprocessing_queued": (90, "后处理已排队"),
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_canceled": (100, "任务已取消"),
"generation_completed": (100, "生成完成"),
"generation_stale_failed": (100, "任务超时已收敛"),
}
@@ -106,6 +136,8 @@ def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
def _failure_label(job: GenerationJob) -> str:
if job.status == "canceled":
return "任务已取消"
if job.current_step == "generation_stale_failed":
return "任务超时"
if job.output_mode == "asset_retry":
@@ -196,7 +228,7 @@ async def claim_generation_job_for_worker(
.where(
GenerationJob.id == job_id,
GenerationJob.status == "running",
GenerationJob.current_step == "request_accepted",
GenerationJob.current_step.in_(["request_accepted", "retry_queued"]),
)
.values(current_step="worker_started")
)
@@ -283,6 +315,8 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
"status": job.status,
"current_step": job.current_step,
**progress,
"can_cancel": generation_job_can_cancel(job),
"can_retry": generation_job_can_retry(job),
"result_snapshot": job.result_snapshot or {},
"error_message": job.error_message,
"created_at": job.created_at,
@@ -290,6 +324,88 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
}
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch the generation job with ``job_id`` that belongs to ``user_id``.

    Args:
        db: Async database session used to run the lookup.
        job_id: Identifier of the generation job to load.
        user_id: Identifier of the user who must own the job.

    Returns:
        The matching ``GenerationJob`` row.

    Raises:
        HTTPException: 404 when no job matches both the job id and the user
            id, so another user's job is reported the same way as a missing
            one.
    """
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    owned_job = (await db.execute(owned_job_query)).scalar_one_or_none()
    if owned_job is not None:
        return owned_job
    raise HTTPException(status_code=404, detail="Generation job not found")
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Cancel, or request cancellation of, a generation job owned by a user.

    A job still waiting at the ``request_accepted`` or ``retry_queued`` step
    is finished right away with status ``canceled``. Any other non-terminal
    job gets a ``cancel_requested`` event recorded; per that event's message
    the worker is expected to stop at its next safe checkpoint. Repeating the
    request for a job that is already canceled, or whose step is already
    ``cancel_requested``, returns the current summary without further changes.

    Args:
        db: Async database session used for lookups and writes.
        job_id: Identifier of the generation job to cancel.
        user_id: Identifier of the user who must own the job.

    Returns:
        The job summary produced by ``generation_job_to_summary``.

    Raises:
        HTTPException: 404 (from ``get_generation_job_for_user``) when the
            job does not exist for this user; 409 when the job's output mode
            does not support queue control, or when the job is in a terminal
            status other than ``canceled``.
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)

    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")

    # The order of these guards matters: an already-canceled job is answered
    # with its summary, while any other terminal status is rejected before
    # the pending-cancel check is consulted.
    if job.status == "canceled":
        return generation_job_to_summary(job)
    if _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")
    if job.current_step == "cancel_requested":
        return generation_job_to_summary(job)

    awaiting_worker = job.current_step in {"request_accepted", "retry_queued"}

    if not awaiting_worker:
        # Read the step before recording the event.
        # NOTE(review): record_generation_event presumably updates the job's
        # step as part of recording — confirm against its implementation.
        step_before_request = job.current_step
        job.error_message = "Cancellation requested by user."
        event_metadata = {"requested_from_step": step_before_request}
        await record_generation_event(
            db,
            job=job,
            story_id=job.story_id,
            event_type="cancel_requested",
            status="running",
            message="Cancellation requested; worker will stop at the next safe checkpoint.",
            metadata=event_metadata,
            commit=False,
        )
        # The event is recorded with commit=False, so commit explicitly here
        # and reload the job before summarizing it.
        await db.commit()
        await db.refresh(job)
        return generation_job_to_summary(job)

    # The job is still at a queued step: finish it as canceled immediately,
    # passing along the linked story when one exists for the same user.
    linked_story = None
    if job.story_id is not None:
        story_query = select(Story).where(
            Story.id == job.story_id,
            Story.user_id == job.user_id,
        )
        story_result = await db.execute(story_query)
        linked_story = story_result.scalar_one_or_none()

    await finish_generation_job(
        db,
        job=job,
        story=linked_story,
        status="canceled",
        current_step="generation_canceled",
        error_message="Generation canceled by user before worker execution started.",
        message="Generation job was canceled before worker execution started.",
    )
    return generation_job_to_summary(job)
async def get_generation_job_detail(
db: AsyncSession,
*,