feat: add generation job cancel and retry queue
This commit is contained in:
@@ -15,6 +15,8 @@ from app.services.adapters.storybook.primary import Storybook, StorybookPage
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.generation_jobs import (
|
||||
create_generation_job,
|
||||
finish_generation_job,
|
||||
get_generation_job_detail,
|
||||
mark_stale_generation_jobs,
|
||||
record_generation_event,
|
||||
)
|
||||
@@ -847,3 +849,187 @@ async def test_retry_assets_rejects_when_story_has_active_job(
|
||||
assert "已有运行中的任务" in response.json()["detail"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_cancel_queued_generation_job_marks_it_canceled(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
|
||||
try:
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小狐狸, 月亮船",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["generation_job_id"]
|
||||
mock_delay.assert_called_once_with(job_id)
|
||||
|
||||
cancel_response = await client.post(f"/api/generations/jobs/{job_id}/cancel")
|
||||
|
||||
assert cancel_response.status_code == 200
|
||||
canceled_job = cancel_response.json()
|
||||
assert canceled_job["status"] == "canceled"
|
||||
assert canceled_job["current_step"] == "generation_canceled"
|
||||
assert canceled_job["can_cancel"] is False
|
||||
assert canceled_job["can_retry"] is True
|
||||
|
||||
detail = await get_generation_job_detail(
|
||||
db_session,
|
||||
job_id=job_id,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
assert [event["event_type"] for event in detail["events"]] == [
|
||||
"request_accepted",
|
||||
"generation_canceled",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_cancel_running_generation_job_marks_cancel_requested(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=test_user.id,
|
||||
output_mode="story",
|
||||
input_type="keywords",
|
||||
request_payload={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小熊, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
await record_generation_event(
|
||||
db_session,
|
||||
job=job,
|
||||
event_type="worker_started",
|
||||
status="running",
|
||||
message="Generation worker started processing the accepted request.",
|
||||
)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(f"/api/generations/jobs/{job.id}/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "running"
|
||||
assert data["current_step"] == "cancel_requested"
|
||||
assert data["can_cancel"] is False
|
||||
assert data["can_retry"] is False
|
||||
|
||||
refreshed_job = (
|
||||
await db_session.execute(select(GenerationJob).where(GenerationJob.id == job.id))
|
||||
).scalar_one()
|
||||
assert refreshed_job.current_step == "cancel_requested"
|
||||
assert refreshed_job.error_message == "Cancellation requested by user."
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_retry_failed_generation_job_requeues_new_worker_job(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
mock_text_provider,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
|
||||
failed_job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=test_user.id,
|
||||
output_mode="story",
|
||||
input_type="keywords",
|
||||
request_payload={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小鹿, 星星",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
await finish_generation_job(
|
||||
db_session,
|
||||
job=failed_job,
|
||||
story=None,
|
||||
status="failed",
|
||||
current_step="generation_failed",
|
||||
error_message="upstream timeout",
|
||||
message="Generation failed before a durable story result was available.",
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] != failed_job.id
|
||||
assert data["status"] == "running"
|
||||
assert data["current_step"] == "retry_queued"
|
||||
assert data["can_cancel"] is True
|
||||
assert data["can_retry"] is False
|
||||
mock_delay.assert_called_once_with(data["id"])
|
||||
|
||||
retried_job_id = data["id"]
|
||||
await run_generation_job_service(retried_job_id, db_session)
|
||||
|
||||
retried_job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(GenerationJob.id == retried_job_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert retried_job.status == "completed"
|
||||
assert retried_job.current_step == "generation_completed"
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == retried_job_id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
assert [event.event_type for event in events[:3]] == [
|
||||
"request_accepted",
|
||||
"retry_queued",
|
||||
"worker_started",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
Reference in New Issue
Block a user