feat: add generation job cancel and retry queue
This commit is contained in:
@@ -16,6 +16,26 @@ from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_terminal_status(status: str) -> bool:
    """Return True when ``status`` names a finished generation job state.

    The finished states are "completed", "degraded_completed", "failed"
    and "canceled"; any other value yields False.
    """
    finished_states = {"completed", "degraded_completed", "failed", "canceled"}
    return status in finished_states
|
||||
|
||||
|
||||
def _job_supports_queue_control(job: GenerationJob) -> bool:
    """Return True when the job's ``output_mode`` is "story" or "storybook".

    Only these two output modes are eligible for the cancel/retry checks
    that build on this helper.
    """
    controllable_modes = {"story", "storybook"}
    return job.output_mode in controllable_modes
|
||||
|
||||
|
||||
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True when a cancel request may be issued for ``job``.

    All three conditions must hold: the job's output mode supports queue
    control, its status is "running", and its current step is not already
    "cancel_requested".
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    # A pending cancel request must not be offered a second time.
    return job.current_step != "cancel_requested"
|
||||
|
||||
|
||||
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True when ``job`` may be retried.

    The job's output mode must support queue control and its status must
    be either "failed" or "canceled".
    """
    if not _job_supports_queue_control(job):
        return False
    retryable_statuses = {"failed", "canceled"}
    return job.status in retryable_statuses
|
||||
|
||||
|
||||
def _story_snapshot(story: Story | None) -> dict[str, Any]:
|
||||
if story is None:
|
||||
return {}
|
||||
@@ -50,6 +70,13 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
if job.status == "canceled":
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
"progress_label": "已取消",
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
if job.status in {"completed", "degraded_completed"}:
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
@@ -59,7 +86,9 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
|
||||
progress_map: dict[str, tuple[int, str]] = {
|
||||
"request_accepted": (5, "已接收请求"),
|
||||
"retry_queued": (8, "重新排队中"),
|
||||
"worker_started": (12, "后台任务已开始"),
|
||||
"cancel_requested": (15, "已请求取消"),
|
||||
"context_prepared": (20, "上下文已准备"),
|
||||
"narrative_generated": (45, "正文已生成"),
|
||||
"story_saved": (60, "主记录已保存"),
|
||||
@@ -83,6 +112,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"postprocessing_queued": (90, "后处理已排队"),
|
||||
"asset_generation_completed": (100, "资源已完成"),
|
||||
"asset_retry_completed": (100, "资源重试完成"),
|
||||
"generation_canceled": (100, "任务已取消"),
|
||||
"generation_completed": (100, "生成完成"),
|
||||
"generation_stale_failed": (100, "任务超时已收敛"),
|
||||
}
|
||||
@@ -106,6 +136,8 @@ def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
|
||||
|
||||
|
||||
def _failure_label(job: GenerationJob) -> str:
|
||||
if job.status == "canceled":
|
||||
return "任务已取消"
|
||||
if job.current_step == "generation_stale_failed":
|
||||
return "任务超时"
|
||||
if job.output_mode == "asset_retry":
|
||||
@@ -196,7 +228,7 @@ async def claim_generation_job_for_worker(
|
||||
.where(
|
||||
GenerationJob.id == job_id,
|
||||
GenerationJob.status == "running",
|
||||
GenerationJob.current_step == "request_accepted",
|
||||
GenerationJob.current_step.in_(["request_accepted", "retry_queued"]),
|
||||
)
|
||||
.values(current_step="worker_started")
|
||||
)
|
||||
@@ -283,6 +315,8 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
|
||||
"status": job.status,
|
||||
"current_step": job.current_step,
|
||||
**progress,
|
||||
"can_cancel": generation_job_can_cancel(job),
|
||||
"can_retry": generation_job_can_retry(job),
|
||||
"result_snapshot": job.result_snapshot or {},
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at,
|
||||
@@ -290,6 +324,88 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch the generation job with ``job_id`` that belongs to ``user_id``.

    Args:
        db: Session used to run the lookup query.
        job_id: Identifier of the job to load.
        user_id: Identifier of the user who must own the job.

    Returns:
        The matching ``GenerationJob`` row.

    Raises:
        HTTPException: 404 when no job matches both the id and the owner.
    """

    # Filtering on the owner in the same query means a job owned by another
    # user is reported exactly like a missing one.
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    lookup = await db.execute(owned_job_query)
    owned_job = lookup.scalar_one_or_none()
    if owned_job is not None:
        return owned_job
    raise HTTPException(status_code=404, detail="Generation job not found")
|
||||
|
||||
|
||||
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Handle a user's request to cancel one of their generation jobs.

    The guards run in a fixed order:

    1. The job is loaded for ``user_id`` (a missing job raises there).
    2. Jobs whose output mode has no queue control are rejected with 409.
    3. An already canceled job is returned unchanged.
    4. Any other terminal job is rejected with 409.
    5. A job whose current step is already "cancel_requested" is returned
       unchanged.
    6. A job still at "request_accepted" or "retry_queued" is finished
       immediately with status "canceled".
    7. Otherwise a "cancel_requested" event is recorded against the running
       job and the session is committed.

    Args:
        db: Session used for the lookups and the commit.
        job_id: Identifier of the job to cancel.
        user_id: Identifier of the requesting user.

    Returns:
        The job summary produced by ``generation_job_to_summary``.

    Raises:
        HTTPException: 409 when the job does not support cancellation or is
            already in a non-canceled terminal state.
    """

    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)

    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")

    # Repeating a cancel on an already canceled job is answered with the
    # current summary instead of an error.
    if job.status == "canceled":
        return generation_job_to_summary(job)

    if _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止，无法取消")

    # A cancel request is already pending; nothing new to record.
    if job.current_step == "cancel_requested":
        return generation_job_to_summary(job)

    still_queued = job.current_step in {"request_accepted", "retry_queued"}
    if still_queued:
        # The job is still at a queued step, so it is finished as canceled
        # right away.
        linked_story = None
        if job.story_id is not None:
            story_query = select(Story).where(
                Story.id == job.story_id,
                Story.user_id == job.user_id,
            )
            story_lookup = await db.execute(story_query)
            linked_story = story_lookup.scalar_one_or_none()

        await finish_generation_job(
            db,
            job=job,
            story=linked_story,
            status="canceled",
            current_step="generation_canceled",
            error_message="Generation canceled by user before worker execution started.",
            message="Generation job was canceled before worker execution started.",
        )
        return generation_job_to_summary(job)

    # Capture the step before recording the event so the event metadata
    # shows where the request arrived.
    step_at_request = job.current_step
    job.error_message = "Cancellation requested by user."
    # NOTE(review): this function never assigns job.current_step itself;
    # presumably record_generation_event advances it to "cancel_requested".
    # Confirm against that helper.
    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type="cancel_requested",
        status="running",
        message="Cancellation requested; worker will stop at the next safe checkpoint.",
        metadata={"requested_from_step": step_at_request},
        commit=False,
    )
    # The event was recorded with commit=False, so the commit happens here.
    await db.commit()
    await db.refresh(job)
    return generation_job_to_summary(job)
|
||||
|
||||
|
||||
async def get_generation_job_detail(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user