feat: add generation job cancel and retry queue
This commit is contained in:
@@ -37,6 +37,9 @@ from app.services.generation_jobs import (
|
||||
create_generation_job,
|
||||
ensure_no_active_story_generation_job,
|
||||
finish_generation_job,
|
||||
generation_job_can_retry,
|
||||
generation_job_to_summary,
|
||||
get_generation_job_for_user,
|
||||
record_generation_event,
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
@@ -73,6 +76,10 @@ class AssetCompletionResult:
|
||||
return self.status == StoryAssetStatus.READY and self.error is None
|
||||
|
||||
|
||||
class GenerationJobCanceledError(Exception):
    """Signal that a worker-run generation job was canceled by the user.

    Raised by the cancellation checkpoint helper once the job has been
    finished with status ``"canceled"``, so callers can unwind the workflow.
    """
|
||||
|
||||
|
||||
async def _record_job_event_if_present(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -99,6 +106,33 @@ async def _record_job_event_if_present(
|
||||
)
|
||||
|
||||
|
||||
async def _stop_if_job_cancel_requested(
    db: AsyncSession,
    *,
    job,
    story: Story | None = None,
) -> None:
    """Abort the workflow at a checkpoint when cancellation was requested.

    Args:
        db: Session used to reload the job and to persist the final state.
        job: The generation job being executed, or ``None`` when the caller
            runs without a tracked job (the call is then a no-op).
        story: Optional story forwarded to ``finish_generation_job``.

    Raises:
        GenerationJobCanceledError: When the reloaded job's ``current_step``
            equals ``"cancel_requested"``; the job is finished with status
            ``"canceled"`` before the exception is raised.
    """

    if job is None:
        return

    # Reload the row so the check sees the job's latest ``current_step``.
    await db.refresh(job)

    if job.current_step == "cancel_requested":
        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="canceled",
            current_step="generation_canceled",
            error_message="Generation canceled by user.",
            message="Generation job was canceled after a user request.",
        )
        raise GenerationJobCanceledError()
|
||||
|
||||
|
||||
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
|
||||
"""Build JSON-safe metadata for asset workflow events."""
|
||||
|
||||
@@ -192,6 +226,7 @@ async def _prepare_generation_context(
|
||||
"has_memory_context": bool(memory_context),
|
||||
},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
return resolved_profile_id, resolved_universe_id, memory_context
|
||||
|
||||
|
||||
@@ -318,6 +353,7 @@ async def _generate_storybook_image_assets(
|
||||
]
|
||||
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -334,6 +370,7 @@ async def _generate_storybook_image_assets(
|
||||
nonlocal cover_failed
|
||||
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
try:
|
||||
return await generate_image(
|
||||
storybook.cover_prompt,
|
||||
@@ -350,6 +387,7 @@ async def _generate_storybook_image_assets(
|
||||
if not page.image_prompt or page.image_url:
|
||||
return
|
||||
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
try:
|
||||
page.image_url = await generate_image(
|
||||
page.image_prompt,
|
||||
@@ -506,6 +544,7 @@ async def _complete_cover_image_asset(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -517,6 +556,7 @@ async def _complete_cover_image_asset(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
db=db,
|
||||
@@ -605,6 +645,7 @@ async def _complete_storybook_image_assets(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -620,6 +661,7 @@ async def _complete_storybook_image_assets(
|
||||
completed_pages: list[int] = []
|
||||
|
||||
if story.cover_prompt and not story.image_url:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
try:
|
||||
story.image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
@@ -658,6 +700,7 @@ async def _complete_storybook_image_assets(
|
||||
if not page.get("image_prompt") or page.get("image_url"):
|
||||
continue
|
||||
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
try:
|
||||
page["image_url"] = await generate_image(
|
||||
page["image_prompt"],
|
||||
@@ -800,6 +843,7 @@ async def _complete_audio_asset(
|
||||
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
@@ -811,6 +855,7 @@ async def _complete_audio_asset(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
audio_data = await text_to_speech(
|
||||
story.story_text,
|
||||
db=db,
|
||||
@@ -933,6 +978,7 @@ async def generate_and_save_story(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
@@ -955,8 +1001,9 @@ async def generate_and_save_story(
|
||||
message="Story narrative was generated.",
|
||||
metadata={"mode": result.mode, "title": result.title},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
|
||||
return await _persist_text_story_result(
|
||||
story = await _persist_text_story_result(
|
||||
result=result,
|
||||
user_id=user_id,
|
||||
profile_id=profile_id,
|
||||
@@ -964,6 +1011,8 @@ async def generate_and_save_story(
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
return story
|
||||
|
||||
|
||||
async def generate_full_story_service(
|
||||
@@ -975,6 +1024,7 @@ async def generate_full_story_service(
|
||||
) -> FullStoryResponse:
|
||||
"""Generate story with parallel image generation."""
|
||||
story = await generate_and_save_story(request, user_id, db, job=job)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
image_url: str | None = None
|
||||
errors: dict[str, str | None] = {}
|
||||
|
||||
@@ -1036,6 +1086,7 @@ async def generate_storybook_service(
|
||||
)
|
||||
|
||||
try:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
@@ -1060,12 +1111,14 @@ async def generate_storybook_service(
|
||||
"page_count": len(storybook.pages),
|
||||
},
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
|
||||
final_cover_url = storybook.cover_url
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
|
||||
if request.generate_images:
|
||||
await _stop_if_job_cancel_requested(db, job=job)
|
||||
(
|
||||
final_cover_url,
|
||||
cover_failed,
|
||||
@@ -1089,6 +1142,7 @@ async def generate_storybook_service(
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
await _stop_if_job_cancel_requested(db, job=job, story=story)
|
||||
|
||||
response_pages = _storybook_pages_to_response(pages_data)
|
||||
|
||||
@@ -1124,6 +1178,18 @@ async def generate_generation_service(
|
||||
request_payload=request.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
await _dispatch_generation_job(db, job=job)
|
||||
|
||||
return _build_queued_generation_response(request, job_id=job.id)
|
||||
|
||||
|
||||
async def _dispatch_generation_job(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job: GenerationJob,
|
||||
) -> None:
|
||||
"""Dispatch one accepted generation job to the background worker."""
|
||||
|
||||
try:
|
||||
from app.tasks.generation_workflow import run_generation_workflow_task
|
||||
|
||||
@@ -1144,8 +1210,6 @@ async def generate_generation_service(
|
||||
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
|
||||
) from exc
|
||||
|
||||
return _build_queued_generation_response(request, job_id=job.id)
|
||||
|
||||
|
||||
def _build_queued_generation_response(
|
||||
request: GenerationRequest,
|
||||
@@ -1184,6 +1248,8 @@ async def execute_generation_job_service(
|
||||
db,
|
||||
job=job,
|
||||
)
|
||||
except GenerationJobCanceledError:
|
||||
return _build_canceled_generation_response(job)
|
||||
except HTTPException as exc:
|
||||
await finish_generation_job(
|
||||
db,
|
||||
@@ -1210,6 +1276,24 @@ async def execute_generation_job_service(
|
||||
return response
|
||||
|
||||
|
||||
def _build_canceled_generation_response(job: GenerationJob) -> GenerationResponse:
    """Assemble the response returned for a job that ended as canceled.

    Values are read from ``job.result_snapshot`` (treated as empty when it is
    falsy). Missing or falsy status fields fall back to fixed defaults, and
    ``last_error`` falls back to the standard cancellation message.
    """

    saved = job.result_snapshot or {}

    def _text(key: str, fallback: str) -> str:
        # Falsy snapshot values (missing, None, "") are replaced by the fallback.
        return str(saved.get(key) or fallback)

    return GenerationResponse(
        id=saved.get("story_id"),
        generation_job_id=job.id,
        title="生成任务已取消",
        mode="storybook" if job.output_mode == "storybook" else "generated",
        generation_status=_text("generation_status", "failed"),
        text_status=_text("text_status", "failed"),
        image_status=_text("image_status", "not_requested"),
        audio_status=_text("audio_status", "not_requested"),
        last_error=_text("last_error", "Generation canceled by user."),
        retryable_assets=list(saved.get("retryable_assets") or []),
    )
|
||||
|
||||
|
||||
async def run_generation_job_service(
|
||||
job_id: str,
|
||||
db: AsyncSession,
|
||||
@@ -1225,6 +1309,46 @@ async def run_generation_job_service(
|
||||
return job
|
||||
|
||||
|
||||
async def retry_generation_job_service(
    job_id: str,
    user_id: str,
    db: AsyncSession,
) -> dict:
    """Queue a new generation job copied from an earlier retryable one.

    The earlier job is looked up for ``user_id``. A new job is created with
    the same output mode, input type, request payload and story id, a
    ``retry_queued`` event is recorded on it, and it is dispatched to the
    background worker.

    Args:
        job_id: Identifier of the job to retry.
        user_id: Owner used for the lookup and for the new job.
        db: Session used for all reads and writes.

    Returns:
        The summary dict produced by ``generation_job_to_summary`` for the
        newly created job, after it has been refreshed.

    Raises:
        HTTPException: With status 409 when ``generation_job_can_retry``
            rejects the earlier job.
    """

    previous_job = await get_generation_job_for_user(
        db,
        job_id=job_id,
        user_id=user_id,
    )
    retry_allowed = generation_job_can_retry(previous_job)
    if not retry_allowed:
        raise HTTPException(status_code=409, detail="当前任务还不能重新排队")

    # Only jobs already linked to a story need the active-job check.
    if previous_job.story_id is not None:
        await ensure_no_active_story_generation_job(
            db,
            story_id=previous_job.story_id,
            user_id=user_id,
        )

    queued_job = await create_generation_job(
        db,
        user_id=user_id,
        output_mode=previous_job.output_mode,
        input_type=previous_job.input_type,
        request_payload=previous_job.request_payload or {},
        story_id=previous_job.story_id,
    )

    # Record which job this retry came from.
    await record_generation_event(
        db,
        job=queued_job,
        story_id=queued_job.story_id,
        event_type="retry_queued",
        status="queued",
        message="Retry job accepted from a previous terminal generation.",
        metadata={"source_job_id": previous_job.id},
    )

    await _dispatch_generation_job(db, job=queued_job)

    # Reload before summarising so the summary reflects the stored row.
    await db.refresh(queued_job)
    return generation_job_to_summary(queued_job)
|
||||
|
||||
|
||||
async def _generate_generation_service_with_job(
|
||||
request: GenerationRequest,
|
||||
user_id: str,
|
||||
|
||||
Reference in New Issue
Block a user