feat: add generation job cancel and retry queue

This commit is contained in:
2026-04-19 18:45:34 +08:00
parent 6fb128955f
commit b89ca96e4b
18 changed files with 756 additions and 51 deletions

View File

@@ -15,6 +15,8 @@ from app.services.adapters.storybook.primary import Storybook, StorybookPage
from app.services.adapters.text.models import StoryOutput
from app.services.generation_jobs import (
create_generation_job,
finish_generation_job,
get_generation_job_detail,
mark_stale_generation_jobs,
record_generation_event,
)
@@ -847,3 +849,187 @@ async def test_retry_assets_rejects_when_story_has_active_job(
assert "已有运行中的任务" in response.json()["detail"]
finally:
app.dependency_overrides.clear()
async def test_cancel_queued_generation_job_marks_it_canceled(
db_session,
auth_token,
test_user,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
try:
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小狐狸, 月亮船",
"generate_images": False,
},
)
assert response.status_code == 202
job_id = response.json()["generation_job_id"]
mock_delay.assert_called_once_with(job_id)
cancel_response = await client.post(f"/api/generations/jobs/{job_id}/cancel")
assert cancel_response.status_code == 200
canceled_job = cancel_response.json()
assert canceled_job["status"] == "canceled"
assert canceled_job["current_step"] == "generation_canceled"
assert canceled_job["can_cancel"] is False
assert canceled_job["can_retry"] is True
detail = await get_generation_job_detail(
db_session,
job_id=job_id,
user_id=test_user.id,
)
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"generation_canceled",
]
finally:
app.dependency_overrides.clear()
async def test_cancel_running_generation_job_marks_cancel_requested(
db_session,
auth_token,
test_user,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
job = await create_generation_job(
db_session,
user_id=test_user.id,
output_mode="story",
input_type="keywords",
request_payload={
"output_mode": "story",
"type": "keywords",
"data": "小熊, 森林",
"generate_images": False,
},
)
await record_generation_event(
db_session,
job=job,
event_type="worker_started",
status="running",
message="Generation worker started processing the accepted request.",
)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(f"/api/generations/jobs/{job.id}/cancel")
assert response.status_code == 200
data = response.json()
assert data["status"] == "running"
assert data["current_step"] == "cancel_requested"
assert data["can_cancel"] is False
assert data["can_retry"] is False
refreshed_job = (
await db_session.execute(select(GenerationJob).where(GenerationJob.id == job.id))
).scalar_one()
assert refreshed_job.current_step == "cancel_requested"
assert refreshed_job.error_message == "Cancellation requested by user."
finally:
app.dependency_overrides.clear()
async def test_retry_failed_generation_job_requeues_new_worker_job(
db_session,
auth_token,
test_user,
mock_text_provider,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
failed_job = await create_generation_job(
db_session,
user_id=test_user.id,
output_mode="story",
input_type="keywords",
request_payload={
"output_mode": "story",
"type": "keywords",
"data": "小鹿, 星星",
"generate_images": False,
},
)
await finish_generation_job(
db_session,
job=failed_job,
story=None,
status="failed",
current_step="generation_failed",
error_message="upstream timeout",
message="Generation failed before a durable story result was available.",
)
try:
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")
assert response.status_code == 200
data = response.json()
assert data["id"] != failed_job.id
assert data["status"] == "running"
assert data["current_step"] == "retry_queued"
assert data["can_cancel"] is True
assert data["can_retry"] is False
mock_delay.assert_called_once_with(data["id"])
retried_job_id = data["id"]
await run_generation_job_service(retried_job_id, db_session)
retried_job = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.id == retried_job_id)
)
).scalar_one()
assert retried_job.status == "completed"
assert retried_job.current_step == "generation_completed"
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == retried_job_id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events[:3]] == [
"request_accepted",
"retry_queued",
"worker_started",
]
finally:
app.dependency_overrides.clear()