129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
"""Generation job tracking tests."""
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import select
|
|
|
|
from app.db.database import get_db
|
|
from app.db.models import GenerationJob, GenerationJobEvent
|
|
from app.main import app
|
|
|
|
# Module-level marker: applies the `asyncio` mark (from the pytest-asyncio
# plugin) to every test in this module so the async tests run on an event loop.
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
async def test_unified_generation_records_job_events_and_retryable_assets(
    db_session,
    test_user,
    auth_token,
    mock_text_provider,
):
    """Story generation persists one completed job with ordered events.

    Posts a keyword-based story request (image generation disabled) to
    ``/api/generations`` and checks:

    * the API response reports ``narrative_ready`` with image and audio
      listed as retryable assets;
    * exactly one ``GenerationJob`` row exists for the test user, linked to
      the returned story and marked ``completed``;
    * that job's ``GenerationJobEvent`` rows, ordered by id, are
      ``request_accepted`` followed by ``generation_completed``.
    """

    async def override_get_db():
        # Route the app's DB dependency to the test session so the rows the
        # request writes are visible to the queries made below.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)

    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.post(
                "/api/generations",
                json={
                    "output_mode": "story",
                    "type": "keywords",
                    # Keywords: "little rabbit, forest".
                    "data": "小兔子, 森林",
                    "generate_images": False,
                },
            )

            # Include the body so a non-200 failure shows the server's error.
            assert response.status_code == 200, response.text
            data = response.json()
            assert data["generation_status"] == "narrative_ready"
            assert data["retryable_assets"] == ["image", "audio"]

            # Exactly one job should have been recorded for this user.
            jobs = (
                await db_session.execute(
                    select(GenerationJob).where(GenerationJob.user_id == test_user.id)
                )
            ).scalars().all()
            assert len(jobs) == 1

            job = jobs[0]
            assert job.story_id == data["id"]
            assert job.output_mode == "story"
            assert job.input_type == "keywords"
            assert job.status == "completed"
            assert job.current_step == "generation_completed"
            # The persisted snapshot must agree with the API response.
            assert job.result_snapshot["retryable_assets"] == ["image", "audio"]

            # Events are ordered by primary key to check the recorded sequence.
            events = (
                await db_session.execute(
                    select(GenerationJobEvent)
                    .where(GenerationJobEvent.job_id == job.id)
                    .order_by(GenerationJobEvent.id)
                )
            ).scalars().all()
            assert [event.event_type for event in events] == [
                "request_accepted",
                "generation_completed",
            ]
    finally:
        # NOTE(review): clear() drops every override, not only get_db. It is
        # kept as-is in case other fixtures rely on this cleanup; narrow it
        # to `pop(get_db, None)` once that is confirmed not to be the case.
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_asset_retry_records_job_events_and_updates_retryable_assets(
    db_session,
    test_user,
    auth_token,
    degraded_story_with_text,
    mock_image_provider,
):
    """Retrying the image asset records an ``asset_retry`` job and its events.

    Posts ``{"assets": ["image"]}`` to the story's ``retry-assets`` endpoint.
    ``degraded_story_with_text`` is presumably a story whose text exists but
    whose assets are still pending or failed; confirm against the fixture.
    The test then checks:

    * the response reports the image as ``ready`` and only audio as still
      retryable;
    * exactly one ``GenerationJob`` with ``output_mode == "asset_retry"``
      exists for the story, and it is ``completed``;
    * that job's events, ordered by id, are ``request_accepted``,
      ``asset_retry_started`` and ``asset_retry_completed``.
    """

    async def override_get_db():
        # Route the app's DB dependency to the test session so the rows the
        # request writes are visible to the queries made below.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)

    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )

            # Include the body so a non-200 failure shows the server's error.
            assert response.status_code == 200, response.text
            data = response.json()
            assert data["image_status"] == "ready"
            # After the image retry succeeds, only audio should remain retryable.
            assert data["retryable_assets"] == ["audio"]

            # The retry should be tracked as its own job, separate from any
            # original generation job for the same story.
            jobs = (
                await db_session.execute(
                    select(GenerationJob).where(
                        GenerationJob.story_id == degraded_story_with_text.id,
                        GenerationJob.output_mode == "asset_retry",
                    )
                )
            ).scalars().all()
            assert len(jobs) == 1

            job = jobs[0]
            assert job.status == "completed"
            assert job.current_step == "asset_retry_completed"
            # The persisted snapshot must agree with the API response.
            assert job.result_snapshot["retryable_assets"] == ["audio"]

            # Events are ordered by primary key to check the recorded sequence.
            events = (
                await db_session.execute(
                    select(GenerationJobEvent)
                    .where(GenerationJobEvent.job_id == job.id)
                    .order_by(GenerationJobEvent.id)
                )
            ).scalars().all()
            assert [event.event_type for event in events] == [
                "request_accepted",
                "asset_retry_started",
                "asset_retry_completed",
            ]
    finally:
        # NOTE(review): clear() drops every override, not only get_db. It is
        # kept as-is in case other fixtures rely on this cleanup; narrow it
        # to `pop(get_db, None)` once that is confirmed not to be the case.
        app.dependency_overrides.clear()
|