feat: track generation jobs

This commit is contained in:
2026-04-18 16:29:22 +08:00
parent 16fafe0fe0
commit 96dfc677e2
18 changed files with 709 additions and 71 deletions

View File

@@ -0,0 +1,128 @@
"""Generation job tracking tests."""
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from app.db.database import get_db
from app.db.models import GenerationJob, GenerationJobEvent
from app.main import app
pytestmark = pytest.mark.asyncio
async def test_unified_generation_records_job_events_and_retryable_assets(
db_session,
test_user,
auth_token,
mock_text_provider,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "story",
"type": "keywords",
"data": "小兔子, 森林",
"generate_images": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "narrative_ready"
assert data["retryable_assets"] == ["image", "audio"]
jobs = (
await db_session.execute(
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.story_id == data["id"]
assert job.output_mode == "story"
assert job.input_type == "keywords"
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"generation_completed",
]
finally:
app.dependency_overrides.clear()
async def test_asset_retry_records_job_events_and_updates_retryable_assets(
db_session,
test_user,
auth_token,
degraded_story_with_text,
mock_image_provider,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
f"/api/generations/{degraded_story_with_text.id}/retry-assets",
json={"assets": ["image"]},
)
assert response.status_code == 200
data = response.json()
assert data["image_status"] == "ready"
assert data["retryable_assets"] == ["audio"]
jobs = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.story_id == degraded_story_with_text.id,
GenerationJob.output_mode == "asset_retry",
)
)
).scalars().all()
assert len(jobs) == 1
job = jobs[0]
assert job.status == "completed"
assert job.current_step == "asset_retry_completed"
assert job.result_snapshot["retryable_assets"] == ["audio"]
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"asset_retry_started",
"asset_retry_completed",
]
finally:
app.dependency_overrides.clear()