feat: move unified generation to background worker

This commit is contained in:
2026-04-19 17:29:37 +08:00
parent 5318de670f
commit 6fb128955f
15 changed files with 632 additions and 285 deletions

View File

@@ -67,13 +67,13 @@ def _mark_legacy_generation_endpoint(response: Response, successor: str) -> None
response.headers.update(_legacy_generation_headers(successor))
@router.post("/generations", response_model=GenerationResponse)
@router.post("/generations", response_model=GenerationResponse, status_code=202)
async def create_generation(
request: GenerationRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Create a story or storybook through the unified generation workflow."""
"""Accept one story/storybook generation request for background execution."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_generation_service(request, user.id, db)

View File

@@ -117,9 +117,9 @@ class StorybookResponse(StoryStatusMixin):
class GenerationResponse(StoryStatusMixin):
"""Unified generation response for the target workflow API."""
id: int
id: int | None = None
generation_job_id: str | None = None
title: str
title: str | None = None
mode: str
story_text: str | None = None
pages: list[StorybookPageResponse] | None = None

View File

@@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy import desc, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
@@ -59,6 +59,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
progress_map: dict[str, tuple[int, str]] = {
"request_accepted": (5, "已接收请求"),
"worker_started": (12, "后台任务已开始"),
"context_prepared": (20, "上下文已准备"),
"narrative_generated": (45, "正文已生成"),
"story_saved": (60, "主记录已保存"),
@@ -66,8 +67,18 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"provider_call_succeeded": (72, "Provider 调用成功"),
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
"cover_image_started": (75, "封面生成中"),
"cover_image_succeeded": (88, "封面已生成"),
"cover_image_failed": (88, "封面生成失败"),
"storybook_images_started": (75, "绘本插图生成中"),
"storybook_cover_image_succeeded": (82, "绘本封面已生成"),
"storybook_cover_image_failed": (82, "绘本封面生成失败"),
"storybook_page_image_succeeded": (86, "分页插图已生成"),
"storybook_page_image_failed": (86, "分页插图生成失败"),
"storybook_images_completed": (92, "绘本插图已完成"),
"audio_started": (75, "音频生成中"),
"audio_cache_hit": (88, "音频缓存已复用"),
"audio_succeeded": (88, "音频已生成"),
"audio_failed": (88, "音频生成失败"),
"asset_retry_started": (25, "资源重试中"),
"postprocessing_queued": (90, "后处理已排队"),
"asset_generation_completed": (100, "资源已完成"),
@@ -155,6 +166,10 @@ async def record_generation_event(
) -> GenerationJobEvent:
"""Append one event to an existing generation job."""
job.current_step = event_type
if story_id is not None:
job.story_id = story_id
event = GenerationJobEvent(
job_id=job.id,
story_id=story_id if story_id is not None else job.story_id,
@@ -169,6 +184,42 @@ async def record_generation_event(
return event
async def claim_generation_job_for_worker(
db: AsyncSession,
*,
job_id: str,
) -> GenerationJob | None:
"""Claim one queued generation job for worker execution once."""
claim_result = await db.execute(
update(GenerationJob)
.where(
GenerationJob.id == job_id,
GenerationJob.status == "running",
GenerationJob.current_step == "request_accepted",
)
.values(current_step="worker_started")
)
await db.commit()
if not claim_result.rowcount:
return None
result = await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
job = result.scalar_one_or_none()
if job is None:
return None
await record_generation_event(
db,
job=job,
event_type="worker_started",
status="running",
message="Generation worker started processing the accepted request.",
)
return job
async def finish_generation_job(
db: AsyncSession,
*,

View File

@@ -12,7 +12,7 @@ from sqlalchemy.orm import joinedload
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.db.models import ChildProfile, GenerationJob, Story, StoryUniverse
from app.schemas.story_schemas import (
AchievementItem,
FullStoryResponse,
@@ -33,6 +33,7 @@ from app.services.audio_storage import (
write_story_audio_cache,
)
from app.services.generation_jobs import (
claim_generation_job_for_worker,
create_generation_job,
ensure_no_active_story_generation_job,
finish_generation_job,
@@ -1113,7 +1114,7 @@ async def generate_generation_service(
user_id: str,
db: AsyncSession,
) -> GenerationResponse:
"""Unified generation workflow entry point for stories and storybooks."""
"""Queue one unified generation workflow for background execution."""
job = await create_generation_job(
db,
@@ -1124,7 +1125,65 @@ async def generate_generation_service(
)
try:
response = await _generate_generation_service_with_job(request, user_id, db, job=job)
from app.tasks.generation_workflow import run_generation_workflow_task
run_generation_workflow_task.delay(job.id)
except Exception as exc:
await finish_generation_job(
db,
job=job,
story=None,
status="failed",
current_step="generation_failed",
error_message="Background generation dispatch failed.",
message="Generation failed before the worker could start processing the job.",
metadata={"dispatch_error": str(exc)},
)
raise HTTPException(
status_code=503,
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
) from exc
return _build_queued_generation_response(request, job_id=job.id)
def _build_queued_generation_response(
request: GenerationRequest,
*,
job_id: str,
) -> GenerationResponse:
"""Build the immediate API response after a generation job is accepted."""
return GenerationResponse(
id=None,
generation_job_id=job_id,
title="生成任务已提交",
mode="storybook" if request.output_mode == "storybook" else "generated",
generation_status="queued",
text_status="generating",
image_status="not_requested",
audio_status="not_requested",
last_error=None,
retryable_assets=[],
child_profile_id=request.child_profile_id,
universe_id=request.universe_id,
)
async def execute_generation_job_service(
job: GenerationJob,
db: AsyncSession,
) -> GenerationResponse:
"""Execute one previously accepted generation job inside the worker."""
try:
request = GenerationRequest.model_validate(job.request_payload or {})
response = await _generate_generation_service_with_job(
request,
job.user_id,
db,
job=job,
)
except HTTPException as exc:
await finish_generation_job(
db,
@@ -1151,6 +1210,21 @@ async def generate_generation_service(
return response
async def run_generation_job_service(
job_id: str,
db: AsyncSession,
) -> GenerationJob | None:
"""Claim and execute one generation job from the background queue."""
job = await claim_generation_job_for_worker(db, job_id=job_id)
if job is None:
logger.info("generation_job_execution_skipped", job_id=job_id)
return None
await execute_generation_job_service(job, db)
return job
async def _generate_generation_service_with_job(
request: GenerationRequest,
user_id: str,

View File

@@ -0,0 +1,38 @@
"""Background execution for unified generation workflows."""
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.story_service import run_generation_job_service
logger = get_logger(__name__)
@celery_app.task
def run_generation_workflow_task(job_id: str):
"""Execute one accepted generation job in the Celery worker."""
logger.info("generation_workflow_task_started", job_id=job_id)
async def _run():
session_factory = _get_session_factory()
async with session_factory() as session:
return await run_generation_job_service(job_id, session)
try:
result = asyncio.run(_run())
logger.info(
"generation_workflow_task_completed",
job_id=job_id,
executed=bool(result),
)
return {"job_id": job_id, "executed": bool(result)}
except Exception as exc:
logger.error(
"generation_workflow_task_failed",
job_id=job_id,
error=str(exc),
)
raise

View File

@@ -18,6 +18,7 @@ from app.services.generation_jobs import (
mark_stale_generation_jobs,
record_generation_event,
)
from app.services.story_service import run_generation_job_service
pytestmark = pytest.mark.asyncio
@@ -45,7 +46,7 @@ def build_storybook_output() -> Storybook:
)
async def test_unified_generation_records_job_events_and_retryable_assets(
async def test_unified_generation_is_queued_then_worker_persists_story_and_events(
db_session,
test_user,
auth_token,
@@ -56,82 +57,108 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
task_delay_path = (
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "partial_ready"
assert data["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"]
jobs = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.story_id == data["id"]
assert job.output_mode == "story"
assert job.input_type == "keywords"
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"] == job.id
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
assert events[1].event_metadata["has_memory_context"] is False
assert events[2].event_metadata["title"] == "小兔子的冒险"
assert events[3].story_id == data["id"]
assert response.status_code == 202
data = response.json()
assert data["id"] is None
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["retryable_assets"] == []
assert data["generation_job_id"]
mock_delay.assert_called_once_with(data["generation_job_id"])
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
assert detail_response.status_code == 200
detail = detail_response.json()
assert detail["id"] == job.id
assert detail["story_id"] == data["id"]
assert detail["progress_percent"] == 100
assert detail["progress_label"] == "已完成"
assert detail["is_terminal"] is True
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
jobs = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.story_id is None
assert job.output_mode == "story"
assert job.input_type == "keywords"
assert job.status == "running"
assert job.current_step == "request_accepted"
assert data["generation_job_id"] == job.id
list_response = await client.get(f"/api/generations/{data['id']}/jobs")
assert list_response.status_code == 200
job_list = list_response.json()
assert [item["id"] for item in job_list] == [job.id]
assert job_list[0]["progress_percent"] == 100
assert job_list[0]["is_terminal"] is True
await run_generation_job_service(job.id, db_session)
job = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.id == job.id)
)
).scalar_one()
assert job.story_id is not None
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
assert events[2].event_metadata["has_memory_context"] is False
assert events[3].event_metadata["title"] == "小兔子的冒险"
assert events[4].story_id == job.story_id
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
assert detail_response.status_code == 200
detail = detail_response.json()
assert detail["id"] == job.id
assert detail["story_id"] == job.story_id
assert detail["progress_percent"] == 100
assert detail["progress_label"] == "已完成"
assert detail["is_terminal"] is True
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
story_response = await client.get(f"/api/generations/{job.story_id}")
assert story_response.status_code == 200
story_data = story_response.json()
assert story_data["generation_status"] == "partial_ready"
assert story_data["retryable_assets"] == ["image", "audio"]
list_response = await client.get(f"/api/generations/{job.story_id}/jobs")
assert list_response.status_code == 200
job_list = list_response.json()
assert [item["id"] for item in job_list] == [job.id]
assert job_list[0]["progress_percent"] == 100
assert job_list[0]["is_terminal"] is True
finally:
app.dependency_overrides.clear()
@@ -196,7 +223,7 @@ async def test_asset_retry_records_job_events_and_updates_retryable_assets(
app.dependency_overrides.clear()
async def test_storybook_generation_records_page_image_events(
async def test_storybook_generation_is_queued_then_worker_records_page_image_events(
db_session,
auth_token,
):
@@ -222,61 +249,78 @@ async def test_storybook_generation_records_page_image_events(
"https://example.com/storybook-page-2.png",
]
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["mode"] == "storybook"
assert data["image_status"] == "ready"
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.story_id == data["id"],
GenerationJob.output_mode == "storybook",
task_delay_path = (
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
)
)
).scalar_one()
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
response = await client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"storybook_images_started",
"storybook_cover_image_succeeded",
"storybook_page_image_succeeded",
"storybook_page_image_succeeded",
"storybook_images_completed",
"story_saved",
"generation_completed",
]
page_events = [
event
for event in events
if event.event_type == "storybook_page_image_succeeded"
]
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
assert events[7].event_metadata["completed_pages"] == [1, 2]
assert response.status_code == 202
data = response.json()
assert data["id"] is None
assert data["mode"] == "storybook"
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
mock_delay.assert_called_once_with(data["generation_job_id"])
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.id == data["generation_job_id"],
)
)
).scalar_one()
await run_generation_job_service(job.id, db_session)
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.id == job.id,
)
)
).scalar_one()
assert job.story_id is not None
assert job.status == "completed"
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"storybook_images_started",
"storybook_cover_image_succeeded",
"storybook_page_image_succeeded",
"storybook_page_image_succeeded",
"storybook_images_completed",
"story_saved",
"generation_completed",
]
page_events = [
event
for event in events
if event.event_type == "storybook_page_image_succeeded"
]
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
assert events[8].event_metadata["completed_pages"] == [1, 2]
finally:
app.dependency_overrides.clear()

View File

@@ -430,6 +430,8 @@ class TestGenerateFull:
class TestUnifiedGenerations:
"""Tests for the target unified generation API."""
TASK_DELAY_PATH = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
def test_create_generation_without_auth(self, client: TestClient):
response = client.post(
"/api/generations",
@@ -443,60 +445,64 @@ class TestUnifiedGenerations:
mock_text_provider,
mock_image_provider,
):
response = auth_client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林, 勇气",
"generate_images": True,
},
)
with patch(self.TASK_DELAY_PATH) as mock_delay:
response = auth_client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林, 勇气",
"generate_images": True,
},
)
assert response.status_code == 200
assert response.status_code == 202
assert "Deprecation" not in response.headers
data = response.json()
assert data["id"] is not None
assert data["id"] is None
assert data["mode"] == "generated"
assert data["story_text"] == "从前有一只小兔子。"
assert data["image_url"] == "https://example.com/image.png"
assert data["cover_url"] == "https://example.com/image.png"
assert data["story_text"] is None
assert data["image_url"] is None
assert data["cover_url"] is None
assert data["pages"] is None
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert data["errors"] == {}
mock_delay.assert_called_once_with(data["generation_job_id"])
def test_create_story_generation_without_assets(
self,
auth_client: TestClient,
mock_text_provider,
):
response = auth_client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
with patch(self.TASK_DELAY_PATH) as mock_delay:
response = auth_client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
assert response.status_code == 200
assert response.status_code == 202
data = response.json()
assert data["mode"] == "generated"
assert data["image_url"] is None
assert data["generation_status"] == "partial_ready"
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["image_status"] == "not_requested"
mock_delay.assert_called_once_with(data["generation_job_id"])
def test_create_story_generation_image_failure(
self,
auth_client: TestClient,
mock_text_provider,
):
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
mock_img.side_effect = Exception("Image API error")
with patch(self.TASK_DELAY_PATH) as mock_delay:
response = auth_client.post(
"/api/generations",
json={
@@ -507,55 +513,45 @@ class TestUnifiedGenerations:
},
)
assert response.status_code == 200
assert response.status_code == 202
data = response.json()
assert data["image_url"] is None
assert data["generation_status"] == "degraded_completed"
assert data["image_status"] == "failed"
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert "Image API error" in data["errors"]["image"]
assert "Image API error" in data["last_error"]
assert data["errors"] == {}
assert data["last_error"] is None
mock_delay.assert_called_once_with(data["generation_job_id"])
def test_create_storybook_generation_success(self, auth_client: TestClient):
with patch(
"app.services.story_service.generate_storybook",
new_callable=AsyncMock,
) as mock_storybook:
with patch(
"app.services.story_service.generate_image",
new_callable=AsyncMock,
) as mock_image:
mock_storybook.return_value = build_storybook_output()
mock_image.side_effect = [
"https://example.com/storybook-cover.png",
"https://example.com/storybook-page-1.png",
"https://example.com/storybook-page-2.png",
]
with patch(self.TASK_DELAY_PATH) as mock_delay:
response = auth_client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
response = auth_client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
assert response.status_code == 202
data = response.json()
assert data["id"] is not None
assert data["id"] is None
assert data["mode"] == "storybook"
assert data["story_text"] is None
assert len(data["pages"]) == 2
assert data["cover_url"] == "https://example.com/storybook-cover.png"
assert data["image_url"] == "https://example.com/storybook-cover.png"
assert data["main_character"] == "小兔子露露"
assert data["art_style"] == "温暖水彩"
assert data["generation_status"] == "completed"
assert data["image_status"] == "ready"
assert data["pages"] is None
assert data["cover_url"] is None
assert data["image_url"] is None
assert data["main_character"] is None
assert data["art_style"] is None
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
mock_delay.assert_called_once_with(data["generation_job_id"])
def test_get_generation_alias(self, auth_client: TestClient, test_story):
response = auth_client.get(f"/api/generations/{test_story.id}")