feat: move unified generation to background worker

This commit is contained in:
2026-04-19 17:29:37 +08:00
parent 5318de670f
commit 6fb128955f
15 changed files with 632 additions and 285 deletions

View File

@@ -18,6 +18,7 @@ from app.services.generation_jobs import (
mark_stale_generation_jobs,
record_generation_event,
)
from app.services.story_service import run_generation_job_service
pytestmark = pytest.mark.asyncio
@@ -45,7 +46,7 @@ def build_storybook_output() -> Storybook:
)
async def test_unified_generation_records_job_events_and_retryable_assets(
async def test_unified_generation_is_queued_then_worker_persists_story_and_events(
db_session,
test_user,
auth_token,
@@ -56,82 +57,108 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
task_delay_path = (
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "partial_ready"
assert data["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"]
jobs = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.story_id == data["id"]
assert job.output_mode == "story"
assert job.input_type == "keywords"
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"] == job.id
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
assert events[1].event_metadata["has_memory_context"] is False
assert events[2].event_metadata["title"] == "小兔子的冒险"
assert events[3].story_id == data["id"]
assert response.status_code == 202
data = response.json()
assert data["id"] is None
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
assert data["retryable_assets"] == []
assert data["generation_job_id"]
mock_delay.assert_called_once_with(data["generation_job_id"])
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
assert detail_response.status_code == 200
detail = detail_response.json()
assert detail["id"] == job.id
assert detail["story_id"] == data["id"]
assert detail["progress_percent"] == 100
assert detail["progress_label"] == "已完成"
assert detail["is_terminal"] is True
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
jobs = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.story_id is None
assert job.output_mode == "story"
assert job.input_type == "keywords"
assert job.status == "running"
assert job.current_step == "request_accepted"
assert data["generation_job_id"] == job.id
list_response = await client.get(f"/api/generations/{data['id']}/jobs")
assert list_response.status_code == 200
job_list = list_response.json()
assert [item["id"] for item in job_list] == [job.id]
assert job_list[0]["progress_percent"] == 100
assert job_list[0]["is_terminal"] is True
await run_generation_job_service(job.id, db_session)
job = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.id == job.id)
)
).scalar_one()
assert job.story_id is not None
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
assert events[2].event_metadata["has_memory_context"] is False
assert events[3].event_metadata["title"] == "小兔子的冒险"
assert events[4].story_id == job.story_id
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
assert detail_response.status_code == 200
detail = detail_response.json()
assert detail["id"] == job.id
assert detail["story_id"] == job.story_id
assert detail["progress_percent"] == 100
assert detail["progress_label"] == "已完成"
assert detail["is_terminal"] is True
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
story_response = await client.get(f"/api/generations/{job.story_id}")
assert story_response.status_code == 200
story_data = story_response.json()
assert story_data["generation_status"] == "partial_ready"
assert story_data["retryable_assets"] == ["image", "audio"]
list_response = await client.get(f"/api/generations/{job.story_id}/jobs")
assert list_response.status_code == 200
job_list = list_response.json()
assert [item["id"] for item in job_list] == [job.id]
assert job_list[0]["progress_percent"] == 100
assert job_list[0]["is_terminal"] is True
finally:
app.dependency_overrides.clear()
@@ -196,7 +223,7 @@ async def test_asset_retry_records_job_events_and_updates_retryable_assets(
app.dependency_overrides.clear()
async def test_storybook_generation_records_page_image_events(
async def test_storybook_generation_is_queued_then_worker_records_page_image_events(
db_session,
auth_token,
):
@@ -222,61 +249,78 @@ async def test_storybook_generation_records_page_image_events(
"https://example.com/storybook-page-2.png",
]
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["mode"] == "storybook"
assert data["image_status"] == "ready"
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.story_id == data["id"],
GenerationJob.output_mode == "storybook",
task_delay_path = (
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
)
)
).scalar_one()
with patch(task_delay_path) as mock_delay:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
response = await client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"storybook_images_started",
"storybook_cover_image_succeeded",
"storybook_page_image_succeeded",
"storybook_page_image_succeeded",
"storybook_images_completed",
"story_saved",
"generation_completed",
]
page_events = [
event
for event in events
if event.event_type == "storybook_page_image_succeeded"
]
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
assert events[7].event_metadata["completed_pages"] == [1, 2]
assert response.status_code == 202
data = response.json()
assert data["id"] is None
assert data["mode"] == "storybook"
assert data["generation_status"] == "queued"
assert data["text_status"] == "generating"
mock_delay.assert_called_once_with(data["generation_job_id"])
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.id == data["generation_job_id"],
)
)
).scalar_one()
await run_generation_job_service(job.id, db_session)
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.id == job.id,
)
)
).scalar_one()
assert job.story_id is not None
assert job.status == "completed"
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"worker_started",
"context_prepared",
"narrative_generated",
"storybook_images_started",
"storybook_cover_image_succeeded",
"storybook_page_image_succeeded",
"storybook_page_image_succeeded",
"storybook_images_completed",
"story_saved",
"generation_completed",
]
page_events = [
event
for event in events
if event.event_type == "storybook_page_image_succeeded"
]
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
assert events[8].event_metadata["completed_pages"] == [1, 2]
finally:
app.dependency_overrides.clear()