Files
dreamweaver/backend/tests/test_generation_jobs.py

806 lines
26 KiB
Python

"""Generation job tracking tests."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from app.db.database import get_db
from app.db.models import GenerationJob, GenerationJobEvent
from app.main import app
from app.services.adapters import AdapterConfig
from app.services.adapters.storybook.primary import Storybook, StorybookPage
from app.services.adapters.text.models import StoryOutput
from app.services.generation_jobs import (
create_generation_job,
mark_stale_generation_jobs,
record_generation_event,
)
# Apply the asyncio marker to every test in this module; all tests below are
# coroutine functions and need an async-aware pytest plugin to await them.
pytestmark = pytest.mark.asyncio
def build_storybook_output() -> Storybook:
    """Build a reusable mocked storybook payload containing two pages."""
    # (page_number, text, image_prompt) for each mocked page, in reading order.
    page_specs = (
        (1, "露露第一次走进会发光的森林。", "Lulu entering a glowing forest"),
        (2, "她遇到了一只会唱歌的萤火虫。", "Lulu meeting a singing firefly"),
    )
    pages = [
        StorybookPage(page_number=number, text=text, image_prompt=prompt)
        for number, text, prompt in page_specs
    ]
    return Storybook(
        title="森林里的发光冒险",
        main_character="小兔子露露",
        art_style="温暖水彩",
        cover_prompt="A glowing forest storybook cover",
        pages=pages,
    )
async def test_unified_generation_records_job_events_and_retryable_assets(
    db_session,
    test_user,
    auth_token,
    mock_text_provider,
):
    """Text-only story generation persists one completed job with ordered events.

    Posts a keyword story request with image generation disabled, then checks
    the stored ``GenerationJob`` row, its ``GenerationJobEvent`` sequence, and
    the job-detail and per-story job-list endpoints.
    """
    # Both the stored events and the job-detail endpoint must report this sequence.
    expected_event_types = [
        "request_accepted",
        "context_prepared",
        "narrative_generated",
        "story_saved",
        "generation_completed",
    ]
    request_body = {
        "output_mode": "story",
        "type": "keywords",
        "data": "小兔子, 森林",
        "generate_images": False,
    }

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            # --- create the generation ---
            response = await client.post("/api/generations", json=request_body)
            assert response.status_code == 200
            payload = response.json()
            assert payload["generation_status"] == "partial_ready"
            assert payload["retryable_assets"] == ["image", "audio"]
            assert payload["generation_job_id"]

            # --- persisted job row ---
            job_rows = await db_session.execute(
                select(GenerationJob).where(GenerationJob.user_id == test_user.id)
            )
            jobs = job_rows.scalars().all()
            assert len(jobs) == 1
            job = jobs[0]
            assert job.story_id == payload["id"]
            assert job.output_mode == "story"
            assert job.input_type == "keywords"
            assert job.status == "completed"
            assert job.current_step == "generation_completed"
            assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
            assert payload["generation_job_id"] == job.id

            # --- persisted events, in insertion order ---
            event_query = (
                select(GenerationJobEvent)
                .where(GenerationJobEvent.job_id == job.id)
                .order_by(GenerationJobEvent.id)
            )
            event_rows = await db_session.execute(event_query)
            events = event_rows.scalars().all()
            assert [event.event_type for event in events] == expected_event_types
            assert events[1].event_metadata["has_memory_context"] is False
            assert events[2].event_metadata["title"] == "小兔子的冒险"
            assert events[3].story_id == payload["id"]

            # --- job detail endpoint ---
            detail_response = await client.get(f"/api/generations/jobs/{job.id}")
            assert detail_response.status_code == 200
            detail = detail_response.json()
            assert detail["id"] == job.id
            assert detail["story_id"] == payload["id"]
            assert detail["progress_percent"] == 100
            assert detail["progress_label"] == "已完成"
            assert detail["is_terminal"] is True
            assert [
                event["event_type"] for event in detail["events"]
            ] == expected_event_types

            # --- per-story job list endpoint ---
            list_response = await client.get(f"/api/generations/{payload['id']}/jobs")
            assert list_response.status_code == 200
            job_list = list_response.json()
            assert [item["id"] for item in job_list] == [job.id]
            first_item = job_list[0]
            assert first_item["progress_percent"] == 100
            assert first_item["is_terminal"] is True
    finally:
        app.dependency_overrides.clear()
async def test_asset_retry_records_job_events_and_updates_retryable_assets(
    db_session,
    test_user,
    auth_token,
    degraded_story_with_text,
    mock_image_provider,
):
    """Retrying the image asset records an ``asset_retry`` job and its events.

    After the retry the response reports the image as ready and only audio
    remains retryable; the stored job and its event sequence must agree.
    """
    expected_event_types = [
        "request_accepted",
        "asset_retry_started",
        "cover_image_started",
        "cover_image_succeeded",
        "asset_retry_completed",
    ]

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            retry_url = f"/api/generations/{degraded_story_with_text.id}/retry-assets"
            response = await client.post(retry_url, json={"assets": ["image"]})
            assert response.status_code == 200
            payload = response.json()
            assert payload["image_status"] == "ready"
            assert payload["retryable_assets"] == ["audio"]

            # Exactly one asset_retry job should exist for this story.
            job_query = select(GenerationJob).where(
                GenerationJob.story_id == degraded_story_with_text.id,
                GenerationJob.output_mode == "asset_retry",
            )
            job_rows = await db_session.execute(job_query)
            jobs = job_rows.scalars().all()
            assert len(jobs) == 1
            job = jobs[0]
            assert job.status == "completed"
            assert job.current_step == "asset_retry_completed"
            assert job.result_snapshot["retryable_assets"] == ["audio"]

            event_query = (
                select(GenerationJobEvent)
                .where(GenerationJobEvent.job_id == job.id)
                .order_by(GenerationJobEvent.id)
            )
            event_rows = await db_session.execute(event_query)
            events = event_rows.scalars().all()
            assert [event.event_type for event in events] == expected_event_types
            assert events[3].event_metadata["asset"] == "cover_image"
    finally:
        app.dependency_overrides.clear()
async def test_storybook_generation_records_page_image_events(
    db_session,
    auth_token,
):
    """Storybook generation records cover and per-page image events in order.

    ``generate_storybook`` and ``generate_image`` in ``story_service`` are
    patched so the storybook has two pages and three image URLs are returned
    (cover first, then one per page).
    """
    expected_event_types = [
        "request_accepted",
        "context_prepared",
        "narrative_generated",
        "storybook_images_started",
        "storybook_cover_image_succeeded",
        "storybook_page_image_succeeded",
        "storybook_page_image_succeeded",
        "storybook_images_completed",
        "story_saved",
        "generation_completed",
    ]
    request_body = {
        "output_mode": "storybook",
        "type": "keywords",
        "data": "森林, 发光, 友情",
        "page_count": 6,
        "generate_images": True,
    }

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    try:
        storybook_patch = patch(
            "app.services.story_service.generate_storybook",
            new_callable=AsyncMock,
        )
        image_patch = patch(
            "app.services.story_service.generate_image",
            new_callable=AsyncMock,
        )
        with storybook_patch as mock_storybook, image_patch as mock_image:
            mock_storybook.return_value = build_storybook_output()
            mock_image.side_effect = [
                "https://example.com/storybook-cover.png",
                "https://example.com/storybook-page-1.png",
                "https://example.com/storybook-page-2.png",
            ]
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                response = await client.post("/api/generations", json=request_body)
                assert response.status_code == 200
                payload = response.json()
                assert payload["mode"] == "storybook"
                assert payload["image_status"] == "ready"

                job_query = select(GenerationJob).where(
                    GenerationJob.story_id == payload["id"],
                    GenerationJob.output_mode == "storybook",
                )
                job = (await db_session.execute(job_query)).scalar_one()

                event_query = (
                    select(GenerationJobEvent)
                    .where(GenerationJobEvent.job_id == job.id)
                    .order_by(GenerationJobEvent.id)
                )
                events = (await db_session.execute(event_query)).scalars().all()
                assert [event.event_type for event in events] == expected_event_types

                # Page events must be reported in page order.
                page_numbers = [
                    event.event_metadata["page_number"]
                    for event in events
                    if event.event_type == "storybook_page_image_succeeded"
                ]
                assert page_numbers == [1, 2]
                # Index 7 is the "storybook_images_completed" event.
                assert events[7].event_metadata["completed_pages"] == [1, 2]
    finally:
        app.dependency_overrides.clear()
async def test_provider_call_events_record_latency_and_cost(
    db_session,
    test_user,
):
    """Provider routing records start/success events with latency and cost metadata.

    The provider lookup and adapter registry are patched so a single "demo"
    provider backed by ``MockAdapter`` handles the text generation call.
    """
    from app.services import provider_router

    mock_result = StoryOutput(
        mode="generated",
        title="带供应商轨迹的故事",
        story_text="一只小鹿学会了复盘。",
        cover_prompt_suggestion="A deer with a golden bookmark",
    )

    # Minimal adapter stand-in: keeps the config and returns the canned output.
    class MockAdapter:
        estimated_cost = 0.0123

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return mock_result

    job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "小鹿"},
    )

    providers_patch = patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
    )
    registry_patch = patch.object(
        provider_router.AdapterRegistry,
        "get",
        return_value=MockAdapter,
    )
    with providers_patch as mock_providers, registry_patch:
        mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)]
        result = await provider_router.generate_story_content(
            input_type="keywords",
            data="小鹿",
            db=db_session,
            generation_job=job,
        )
    assert result == mock_result

    event_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    events = (await db_session.execute(event_query)).scalars().all()
    assert [event.event_type for event in events] == [
        "request_accepted",
        "provider_call_started",
        "provider_call_succeeded",
    ]

    # The success event carries the routing and measurement metadata.
    success_metadata = events[2].event_metadata
    assert success_metadata["capability"] == "text"
    assert success_metadata["adapter"] == "demo"
    assert success_metadata["strategy"] == "priority"
    assert success_metadata["latency_ms"] >= 0
    assert success_metadata["estimated_cost_usd"] == 0.0123
async def test_story_provider_stats_aggregate_job_events(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    """Per-story provider stats aggregate the provider call events of a job.

    Records one succeeded and one failed image provider call against the
    story, then checks the totals and the per-provider breakdown returned by
    ``/api/generations/{story_id}/provider-stats``.
    """

    async def override_get_db():
        yield db_session

    job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "estimated_cost_usd": 0.02,
            "error": "timeout",
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after setup has succeeded and directly before
    # the try/finally that removes it, so a failing setup step cannot leave the
    # override registered on the shared app.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["story_id"] == degraded_story_with_text.id
            assert data["total_calls"] == 2
            assert data["successful_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 81.0
            # The failed call carried a 0.02 estimate, but the expected totals
            # only include the succeeded call's 0.01.
            assert data["estimated_cost_usd"] == 0.01
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
            ]
    finally:
        app.dependency_overrides.clear()
async def test_user_provider_analytics_aggregate_across_stories(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """User-level provider analytics aggregate events across jobs and stories.

    Records two image provider calls (one succeeded, one failed) on one story
    and one succeeded TTS call on another, then checks the totals, failure
    reasons, and per-provider breakdown from
    ``/api/generations/provider-analytics``.
    """

    async def override_get_db():
        yield db_session

    image_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "error": "timeout",
        },
    )
    audio_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_retry",
        input_type="audio",
        request_payload={"assets": ["audio"]},
        story_id=test_story.id,
    )
    await record_generation_event(
        db_session,
        job=audio_job,
        story_id=test_story.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "tts",
            "adapter": "edge_tts",
            "strategy": "priority",
            "latency_ms": 18,
            "estimated_cost_usd": 0.003,
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after setup has succeeded and directly before
    # the try/finally that removes it, so a failing setup step cannot leave the
    # override registered on the shared app.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.get("/api/generations/provider-analytics")
            assert response.status_code == 200
            data = response.json()
            assert data["job_count"] == 2
            assert data["story_count"] == 2
            assert data["total_calls"] == 3
            assert data["successful_calls"] == 2
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 60.0
            assert data["estimated_cost_usd"] == 0.013
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
                {
                    "capability": "tts",
                    "adapter": "edge_tts",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 18.0,
                    "estimated_cost_usd": 0.003,
                },
            ]
    finally:
        app.dependency_overrides.clear()
async def test_provider_analytics_support_days_and_capability_filters(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """Provider analytics honour the ``days`` and ``capability`` query filters.

    A failed image call is back-dated by ten days and a TTS call is recorded
    now. The test then checks that ``days=7`` reports a single call with no
    failure reasons, and that ``capability=image`` reports only the failed
    image call on both the user-level analytics and per-story stats endpoints.
    """

    async def override_get_db():
        yield db_session

    image_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    old_event = await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "error": "timeout",
        },
    )
    # Back-date the failed call so it falls outside the 7-day window queried below.
    old_event.created_at = datetime.now(timezone.utc) - timedelta(days=10)
    await db_session.commit()

    tts_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_retry",
        input_type="audio",
        request_payload={"assets": ["audio"]},
        story_id=test_story.id,
    )
    await record_generation_event(
        db_session,
        job=tts_job,
        story_id=test_story.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "tts",
            "adapter": "edge_tts",
            "strategy": "priority",
            "latency_ms": 18,
            "estimated_cost_usd": 0.003,
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after setup has succeeded and directly before
    # the try/finally that removes it, so a failing setup step cannot leave the
    # override registered on the shared app.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            # Time-window filter.
            response = await client.get("/api/generations/provider-analytics?days=7")
            assert response.status_code == 200
            data = response.json()
            assert data["window_days"] == 7
            assert data["total_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == []

            # Capability filter on the user-level analytics endpoint.
            response = await client.get(
                "/api/generations/provider-analytics?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["total_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]

            # Capability filter on the per-story stats endpoint.
            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]
    finally:
        app.dependency_overrides.clear()
async def test_generation_ops_summary_exposes_running_stale_and_recent_failures(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """The ops summary reports active, stale, failed, and degraded job counts.

    Creates four jobs, back-dates their ``updated_at`` values, marks one as
    failed and one as degraded, then checks the counters and the
    ``recent_failures`` entry returned by
    ``/api/generations/ops-summary?hours=48``.
    """

    async def override_get_db():
        yield db_session

    running_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "星星"},
        story_id=test_story.id,
    )
    stale_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_generation",
        input_type="image",
        request_payload={"story_id": degraded_story_with_text.id},
        story_id=degraded_story_with_text.id,
    )
    failed_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    degraded_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="storybook",
        input_type="keywords",
        request_payload={"data": "月亮"},
        story_id=test_story.id,
    )

    # One reference timestamp keeps the relative ages of the jobs exact.
    now = datetime.now(timezone.utc)
    stale_job.updated_at = now - timedelta(hours=3)
    failed_job.status = "failed"
    failed_job.current_step = "asset_retry_failed"
    failed_job.error_message = "image timeout"
    failed_job.updated_at = now - timedelta(hours=1)
    degraded_job.status = "degraded_completed"
    degraded_job.current_step = "generation_completed"
    degraded_job.updated_at = now - timedelta(minutes=30)
    running_job.updated_at = now - timedelta(minutes=10)
    await db_session.commit()

    transport = ASGITransport(app=app)
    # Install the override only after setup has succeeded and directly before
    # the try/finally that removes it, so a failing setup step cannot leave the
    # override registered on the shared app.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.get("/api/generations/ops-summary?hours=48")
            assert response.status_code == 200
            data = response.json()
            assert data["window_hours"] == 48
            assert data["active_jobs"] == 2
            assert data["stale_running_jobs"] == 1
            assert data["failed_jobs"] == 1
            assert data["degraded_jobs"] == 1
            assert data["asset_retry_jobs"] == 2
            assert len(data["recent_failures"]) == 1
            failure = data["recent_failures"][0]
            assert failure["job_id"] == failed_job.id
            assert failure["story_title"] == degraded_story_with_text.title
            assert failure["failure_label"] == "资源重试失败"
    finally:
        app.dependency_overrides.clear()
async def test_mark_stale_generation_jobs_marks_old_running_jobs_failed(
    db_session,
    degraded_story_with_text,
):
    """A job untouched for two hours is failed by a 30-minute stale sweep.

    Checks the sweep's summary dict, the job's updated status fields, and that
    the last recorded event is ``generation_stale_failed``.
    """
    stale_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "超时任务"},
        story_id=degraded_story_with_text.id,
    )
    two_hours_ago = datetime.now(timezone.utc) - timedelta(hours=2)
    stale_job.updated_at = two_hours_ago
    await db_session.commit()

    sweep_summary = await mark_stale_generation_jobs(db_session, stale_after_minutes=30)
    assert sweep_summary == {"running": 1, "marked_stale": 1, "stale_after_minutes": 30}

    # Re-read the job through a query to observe what the sweep persisted.
    job_query = select(GenerationJob).where(GenerationJob.id == stale_job.id)
    job_rows = await db_session.execute(job_query)
    refreshed_job = job_rows.scalar_one()
    assert refreshed_job.status == "failed"
    assert refreshed_job.current_step == "generation_stale_failed"
    assert refreshed_job.error_message == "Generation job exceeded 30 minutes without progress."

    event_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == stale_job.id)
        .order_by(GenerationJobEvent.id)
    )
    event_rows = await db_session.execute(event_query)
    last_event = event_rows.scalars().all()[-1]
    assert last_event.event_type == "generation_stale_failed"
    assert last_event.event_metadata["stale_after_minutes"] == 30
async def test_retry_assets_rejects_when_story_has_active_job(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    """Retrying assets returns HTTP 409 while the story already has a job.

    A freshly created ``asset_generation`` job is left untouched for the
    story before the retry request is sent.
    """

    async def override_get_db():
        yield db_session

    await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_generation",
        input_type="image",
        request_payload={"story_id": degraded_story_with_text.id},
        story_id=degraded_story_with_text.id,
    )

    transport = ASGITransport(app=app)
    # Install the override only after setup has succeeded and directly before
    # the try/finally that removes it, so a failing setup step cannot leave the
    # override registered on the shared app.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )
            assert response.status_code == 409
            assert "已有运行中的任务" in response.json()["detail"]
    finally:
        app.dependency_overrides.clear()