feat: add generation job cancel and retry queue
This commit is contained in:
@@ -40,6 +40,7 @@ from app.services.generation_jobs import (
|
||||
get_user_generation_ops_summary,
|
||||
get_user_provider_analytics,
|
||||
list_story_generation_jobs,
|
||||
request_generation_job_cancel,
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
from app.services.provider_router import (
|
||||
@@ -88,6 +89,32 @@ async def get_generation_job(
|
||||
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/generations/jobs/{job_id}/cancel",
|
||||
response_model=GenerationJobSummaryResponse,
|
||||
)
|
||||
async def cancel_generation_job(
|
||||
job_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Request cancellation for one queued/running generation job."""
|
||||
return await request_generation_job_cancel(db, job_id=job_id, user_id=user.id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/generations/jobs/{job_id}/retry",
|
||||
response_model=GenerationJobSummaryResponse,
|
||||
)
|
||||
async def retry_generation_job(
|
||||
job_id: str,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Queue one new generation job from a failed/canceled terminal job."""
|
||||
return await story_service.retry_generation_job_service(job_id, user.id, db)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/generations/ops-summary",
|
||||
response_model=GenerationOpsSummaryResponse,
|
||||
|
||||
@@ -195,6 +195,8 @@ class GenerationJobSummaryResponse(BaseModel):
|
||||
progress_percent: int
|
||||
progress_label: str
|
||||
is_terminal: bool
|
||||
can_cancel: bool = False
|
||||
can_retry: bool = False
|
||||
result_snapshot: dict[str, Any] = Field(default_factory=dict)
|
||||
error_message: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
@@ -16,6 +16,26 @@ from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_terminal_status(status: str) -> bool:
|
||||
return status in {"completed", "degraded_completed", "failed", "canceled"}
|
||||
|
||||
|
||||
def _job_supports_queue_control(job: GenerationJob) -> bool:
|
||||
return job.output_mode in {"story", "storybook"}
|
||||
|
||||
|
||||
def generation_job_can_cancel(job: GenerationJob) -> bool:
|
||||
return (
|
||||
_job_supports_queue_control(job)
|
||||
and job.status == "running"
|
||||
and job.current_step != "cancel_requested"
|
||||
)
|
||||
|
||||
|
||||
def generation_job_can_retry(job: GenerationJob) -> bool:
|
||||
return _job_supports_queue_control(job) and job.status in {"failed", "canceled"}
|
||||
|
||||
|
||||
def _story_snapshot(story: Story | None) -> dict[str, Any]:
|
||||
if story is None:
|
||||
return {}
|
||||
@@ -50,6 +70,13 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
if job.status == "canceled":
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
"progress_label": "已取消",
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
if job.status in {"completed", "degraded_completed"}:
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
@@ -59,7 +86,9 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
|
||||
progress_map: dict[str, tuple[int, str]] = {
|
||||
"request_accepted": (5, "已接收请求"),
|
||||
"retry_queued": (8, "重新排队中"),
|
||||
"worker_started": (12, "后台任务已开始"),
|
||||
"cancel_requested": (15, "已请求取消"),
|
||||
"context_prepared": (20, "上下文已准备"),
|
||||
"narrative_generated": (45, "正文已生成"),
|
||||
"story_saved": (60, "主记录已保存"),
|
||||
@@ -83,6 +112,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"postprocessing_queued": (90, "后处理已排队"),
|
||||
"asset_generation_completed": (100, "资源已完成"),
|
||||
"asset_retry_completed": (100, "资源重试完成"),
|
||||
"generation_canceled": (100, "任务已取消"),
|
||||
"generation_completed": (100, "生成完成"),
|
||||
"generation_stale_failed": (100, "任务超时已收敛"),
|
||||
}
|
||||
@@ -106,6 +136,8 @@ def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
|
||||
|
||||
|
||||
def _failure_label(job: GenerationJob) -> str:
|
||||
if job.status == "canceled":
|
||||
return "任务已取消"
|
||||
if job.current_step == "generation_stale_failed":
|
||||
return "任务超时"
|
||||
if job.output_mode == "asset_retry":
|
||||
@@ -196,7 +228,7 @@ async def claim_generation_job_for_worker(
|
||||
.where(
|
||||
GenerationJob.id == job_id,
|
||||
GenerationJob.status == "running",
|
||||
GenerationJob.current_step == "request_accepted",
|
||||
GenerationJob.current_step.in_(["request_accepted", "retry_queued"]),
|
||||
)
|
||||
.values(current_step="worker_started")
|
||||
)
|
||||
@@ -283,6 +315,8 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
|
||||
"status": job.status,
|
||||
"current_step": job.current_step,
|
||||
**progress,
|
||||
"can_cancel": generation_job_can_cancel(job),
|
||||
"can_retry": generation_job_can_retry(job),
|
||||
"result_snapshot": job.result_snapshot or {},
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at,
|
||||
@@ -290,6 +324,88 @@ def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def get_generation_job_for_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job_id: str,
|
||||
user_id: str,
|
||||
) -> GenerationJob:
|
||||
"""Load one generation job owned by the current user."""
|
||||
|
||||
result = await db.execute(
|
||||
select(GenerationJob).where(
|
||||
GenerationJob.id == job_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
)
|
||||
)
|
||||
job = result.scalar_one_or_none()
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Generation job not found")
|
||||
return job
|
||||
|
||||
|
||||
async def request_generation_job_cancel(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Request cancellation for one queued/running generation job."""
|
||||
|
||||
job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
|
||||
|
||||
if not _job_supports_queue_control(job):
|
||||
raise HTTPException(status_code=409, detail="当前任务不支持取消")
|
||||
|
||||
if job.status == "canceled":
|
||||
return generation_job_to_summary(job)
|
||||
|
||||
if _is_terminal_status(job.status):
|
||||
raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")
|
||||
|
||||
if job.current_step == "cancel_requested":
|
||||
return generation_job_to_summary(job)
|
||||
|
||||
if job.current_step in {"request_accepted", "retry_queued"}:
|
||||
story = None
|
||||
if job.story_id is not None:
|
||||
story = (
|
||||
await db.execute(
|
||||
select(Story).where(
|
||||
Story.id == job.story_id,
|
||||
Story.user_id == job.user_id,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
await finish_generation_job(
|
||||
db,
|
||||
job=job,
|
||||
story=story,
|
||||
status="canceled",
|
||||
current_step="generation_canceled",
|
||||
error_message="Generation canceled by user before worker execution started.",
|
||||
message="Generation job was canceled before worker execution started.",
|
||||
)
|
||||
return generation_job_to_summary(job)
|
||||
|
||||
previous_step = job.current_step
|
||||
job.error_message = "Cancellation requested by user."
|
||||
await record_generation_event(
|
||||
db,
|
||||
job=job,
|
||||
story_id=job.story_id,
|
||||
event_type="cancel_requested",
|
||||
status="running",
|
||||
message="Cancellation requested; worker will stop at the next safe checkpoint.",
|
||||
metadata={"requested_from_step": previous_step},
|
||||
commit=False,
|
||||
)
|
||||
await db.commit()
|
||||
await db.refresh(job)
|
||||
return generation_job_to_summary(job)
|
||||
|
||||
|
||||
async def get_generation_job_detail(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
|
||||
@@ -37,6 +37,9 @@ from app.services.generation_jobs import (
|
||||
create_generation_job,
|
||||
ensure_no_active_story_generation_job,
|
||||
finish_generation_job,
|
||||
generation_job_can_retry,
|
||||
generation_job_to_summary,
|
||||
get_generation_job_for_user,
|
||||
record_generation_event,
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
@@ -73,6 +76,10 @@ class AssetCompletionResult:
|
||||
return self.status == StoryAssetStatus.READY and self.error is None
|
||||
|
||||
|
||||
class GenerationJobCanceledError(Exception):
|
||||
"""Raised when a running worker job has been canceled by the user."""
|
||||
|
||||
|
||||
async def _record_job_event_if_present(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -99,6 +106,33 @@ async def _record_job_event_if_present(
|
||||
)
|
||||
|
||||
|
||||
async def _stop_if_job_cancel_requested(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job,
|
||||
story: Story | None = None,
|
||||
) -> None:
|
||||
"""Stop a worker-owned job at the next safe checkpoint after cancellation."""
|
||||
|
||||
if job is None:
|
||||
return
|
||||
|
||||
await db.refresh(job)
|
||||
if job.current_step != "cancel_requested":
|
||||
return
|
||||
|
||||
await finish_generation_job(
|
||||
db,
|
||||
job=job,
|
||||
story=story,
|
||||
status="canceled",
|
||||
current_step="generation_canceled",
|
||||
error_message="Generation canceled by user.",
|
||||
message="Generation job was canceled after a user request.",
|
||||
)
|
||||
raise GenerationJobCanceledError()
|
||||
|
||||
|
||||
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
|
||||
"""Build JSON-safe metadata for asset workflow events."""
|
||||
|
||||
@@ -192,6 +226,7 @@ async def _prepare_generation_context(
|
||||
"has_memory_context": bool(memory_context),
|
||||
},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
return resolved_profile_id, resolved_universe_id, memory_context
|
||||
|
||||
|
||||
@@ -318,6 +353,7 @@ async def _generate_storybook_image_assets(
|
||||
]
|
||||
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -334,6 +370,7 @@ async def _generate_storybook_image_assets(
|
||||
nonlocal cover_failed
|
||||
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
try:
|
||||
return await generate_image(
|
||||
storybook.cover_prompt,
|
||||
@@ -350,6 +387,7 @@ async def _generate_storybook_image_assets(
|
||||
if not page.image_prompt or page.image_url:
|
||||
return
|
||||
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
try:
|
||||
page.image_url = await generate_image(
|
||||
page.image_prompt,
|
||||
@@ -506,6 +544,7 @@ async def _complete_cover_image_asset(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -517,6 +556,7 @@ async def _complete_cover_image_asset(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
db=db,
|
||||
@@ -605,6 +645,7 @@ async def _complete_storybook_image_assets(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -620,6 +661,7 @@ async def _complete_storybook_image_assets(
|
||||
completed_pages: list[int] = []
|
||||
|
||||
if story.cover_prompt and not story.image_url:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
try:
|
||||
story.image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
@@ -658,6 +700,7 @@ async def _complete_storybook_image_assets(
|
||||
if not page.get("image_prompt") or page.get("image_url"):
|
||||
continue
|
||||
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
try:
|
||||
page["image_url"] = await generate_image(
|
||||
page["image_prompt"],
|
||||
@@ -800,6 +843,7 @@ async def _complete_audio_asset(
|
||||
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -811,6 +855,7 @@ async def _complete_audio_asset(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
audio_data = await text_to_speech(
|
||||
story.story_text,
|
||||
db=db,
|
||||
@@ -933,6 +978,7 @@ async def generate_and_save_story(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
@@ -955,8 +1001,9 @@ async def generate_and_save_story(
|
||||
message="Story narrative was generated.",
|
||||
metadata={"mode": result.mode, "title": result.title},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
|
||||
return await _persist_text_story_result(
|
||||
story = await _persist_text_story_result(
|
||||
result=result,
|
||||
user_id=user_id,
|
||||
profile_id=profile_id,
|
||||
@@ -964,6 +1011,8 @@ async def generate_and_save_story(
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
return story
|
||||
|
||||
|
||||
async def generate_full_story_service(
|
||||
@@ -975,6 +1024,7 @@ async def generate_full_story_service(
|
||||
) -> FullStoryResponse:
|
||||
"""Generate story with parallel image generation."""
|
||||
story = await generate_and_save_story(request, user_id, db, job=job)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
image_url: str | None = None
|
||||
errors: dict[str, str | None] = {}
|
||||
|
||||
@@ -1036,6 +1086,7 @@ async def generate_storybook_service(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
@@ -1060,12 +1111,14 @@ async def generate_storybook_service(
|
||||
"page_count": len(storybook.pages),
|
||||
},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
|
||||
final_cover_url = storybook.cover_url
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
|
||||
if request.generate_images:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
(
|
||||
final_cover_url,
|
||||
cover_failed,
|
||||
@@ -1089,6 +1142,7 @@ async def generate_storybook_service(
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
|
||||
response_pages = _storybook_pages_to_response(pages_data)
|
||||
|
||||
@@ -1124,6 +1178,18 @@ async def generate_generation_service(
|
||||
request_payload=request.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
await _dispatch_generation_job(db, job=job)
|
||||
|
||||
return _build_queued_generation_response(request, job_id=job.id)
|
||||
|
||||
|
||||
async def _dispatch_generation_job(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job: GenerationJob,
|
||||
) -> None:
|
||||
"""Dispatch one accepted generation job to the background worker."""
|
||||
|
||||
try:
|
||||
from app.tasks.generation_workflow import run_generation_workflow_task
|
||||
|
||||
@@ -1144,8 +1210,6 @@ async def generate_generation_service(
|
||||
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
|
||||
) from exc
|
||||
|
||||
return _build_queued_generation_response(request, job_id=job.id)
|
||||
|
||||
|
||||
def _build_queued_generation_response(
|
||||
request: GenerationRequest,
|
||||
@@ -1184,6 +1248,8 @@ async def execute_generation_job_service(
|
||||
db,
|
||||
job=job,
|
||||
)
|
||||
except GenerationJobCanceledError:
|
||||
return _build_canceled_generation_response(job)
|
||||
except HTTPException as exc:
|
||||
await finish_generation_job(
|
||||
db,
|
||||
@@ -1210,6 +1276,24 @@ async def execute_generation_job_service(
|
||||
return response
|
||||
|
||||
|
||||
def _build_canceled_generation_response(job: GenerationJob) -> GenerationResponse:
|
||||
"""Build a compact response for a worker job that ended as canceled."""
|
||||
|
||||
snapshot = job.result_snapshot or {}
|
||||
return GenerationResponse(
|
||||
id=snapshot.get("story_id"),
|
||||
generation_job_id=job.id,
|
||||
title="生成任务已取消",
|
||||
mode="storybook" if job.output_mode == "storybook" else "generated",
|
||||
generation_status=str(snapshot.get("generation_status") or "failed"),
|
||||
text_status=str(snapshot.get("text_status") or "failed"),
|
||||
image_status=str(snapshot.get("image_status") or "not_requested"),
|
||||
audio_status=str(snapshot.get("audio_status") or "not_requested"),
|
||||
last_error=str(snapshot.get("last_error") or "Generation canceled by user."),
|
||||
retryable_assets=list(snapshot.get("retryable_assets") or []),
|
||||
)
|
||||
|
||||
|
||||
async def run_generation_job_service(
|
||||
job_id: str,
|
||||
db: AsyncSession,
|
||||
@@ -1225,6 +1309,46 @@ async def run_generation_job_service(
|
||||
return job
|
||||
|
||||
|
||||
async def retry_generation_job_service(
|
||||
job_id: str,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> dict:
|
||||
"""Clone one failed/canceled generation job and queue it again."""
|
||||
|
||||
source_job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)
|
||||
if not generation_job_can_retry(source_job):
|
||||
raise HTTPException(status_code=409, detail="当前任务还不能重新排队")
|
||||
|
||||
if source_job.story_id is not None:
|
||||
await ensure_no_active_story_generation_job(
|
||||
db,
|
||||
story_id=source_job.story_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
retry_job = await create_generation_job(
|
||||
db,
|
||||
user_id=user_id,
|
||||
output_mode=source_job.output_mode,
|
||||
input_type=source_job.input_type,
|
||||
request_payload=source_job.request_payload or {},
|
||||
story_id=source_job.story_id,
|
||||
)
|
||||
await record_generation_event(
|
||||
db,
|
||||
job=retry_job,
|
||||
story_id=retry_job.story_id,
|
||||
event_type="retry_queued",
|
||||
status="queued",
|
||||
message="Retry job accepted from a previous terminal generation.",
|
||||
metadata={"source_job_id": source_job.id},
|
||||
)
|
||||
await _dispatch_generation_job(db, job=retry_job)
|
||||
await db.refresh(retry_job)
|
||||
return generation_job_to_summary(retry_job)
|
||||
|
||||
|
||||
async def _generate_generation_service_with_job(
|
||||
request: GenerationRequest,
|
||||
user_id: str,
|
||||
|
||||
@@ -15,6 +15,8 @@ from app.services.adapters.storybook.primary import Storybook, StorybookPage
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.generation_jobs import (
|
||||
create_generation_job,
|
||||
finish_generation_job,
|
||||
get_generation_job_detail,
|
||||
mark_stale_generation_jobs,
|
||||
record_generation_event,
|
||||
)
|
||||
@@ -847,3 +849,187 @@ async def test_retry_assets_rejects_when_story_has_active_job(
|
||||
assert "已有运行中的任务" in response.json()["detail"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_cancel_queued_generation_job_marks_it_canceled(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
|
||||
try:
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小狐狸, 月亮船",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
job_id = response.json()["generation_job_id"]
|
||||
mock_delay.assert_called_once_with(job_id)
|
||||
|
||||
cancel_response = await client.post(f"/api/generations/jobs/{job_id}/cancel")
|
||||
|
||||
assert cancel_response.status_code == 200
|
||||
canceled_job = cancel_response.json()
|
||||
assert canceled_job["status"] == "canceled"
|
||||
assert canceled_job["current_step"] == "generation_canceled"
|
||||
assert canceled_job["can_cancel"] is False
|
||||
assert canceled_job["can_retry"] is True
|
||||
|
||||
detail = await get_generation_job_detail(
|
||||
db_session,
|
||||
job_id=job_id,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
assert [event["event_type"] for event in detail["events"]] == [
|
||||
"request_accepted",
|
||||
"generation_canceled",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_cancel_running_generation_job_marks_cancel_requested(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=test_user.id,
|
||||
output_mode="story",
|
||||
input_type="keywords",
|
||||
request_payload={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小熊, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
await record_generation_event(
|
||||
db_session,
|
||||
job=job,
|
||||
event_type="worker_started",
|
||||
status="running",
|
||||
message="Generation worker started processing the accepted request.",
|
||||
)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(f"/api/generations/jobs/{job.id}/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "running"
|
||||
assert data["current_step"] == "cancel_requested"
|
||||
assert data["can_cancel"] is False
|
||||
assert data["can_retry"] is False
|
||||
|
||||
refreshed_job = (
|
||||
await db_session.execute(select(GenerationJob).where(GenerationJob.id == job.id))
|
||||
).scalar_one()
|
||||
assert refreshed_job.current_step == "cancel_requested"
|
||||
assert refreshed_job.error_message == "Cancellation requested by user."
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_retry_failed_generation_job_requeues_new_worker_job(
|
||||
db_session,
|
||||
auth_token,
|
||||
test_user,
|
||||
mock_text_provider,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
|
||||
failed_job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=test_user.id,
|
||||
output_mode="story",
|
||||
input_type="keywords",
|
||||
request_payload={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小鹿, 星星",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
await finish_generation_job(
|
||||
db_session,
|
||||
job=failed_job,
|
||||
story=None,
|
||||
status="failed",
|
||||
current_step="generation_failed",
|
||||
error_message="upstream timeout",
|
||||
message="Generation failed before a durable story result was available.",
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] != failed_job.id
|
||||
assert data["status"] == "running"
|
||||
assert data["current_step"] == "retry_queued"
|
||||
assert data["can_cancel"] is True
|
||||
assert data["can_retry"] is False
|
||||
mock_delay.assert_called_once_with(data["id"])
|
||||
|
||||
retried_job_id = data["id"]
|
||||
await run_generation_job_service(retried_job_id, db_session)
|
||||
|
||||
retried_job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(GenerationJob.id == retried_job_id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert retried_job.status == "completed"
|
||||
assert retried_job.current_step == "generation_completed"
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == retried_job_id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
assert [event.event_type for event in events[:3]] == [
|
||||
"request_accepted",
|
||||
"retry_queued",
|
||||
"worker_started",
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
Reference in New Issue
Block a user