# File: dreamweaver/backend/tests/test_generation_jobs.py
#
# File-viewer metadata: 554 lines, 18 KiB, Python

"""Generation job tracking tests."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from app.db.database import get_db
from app.db.models import GenerationJob, GenerationJobEvent
from app.main import app
from app.services.adapters import AdapterConfig
from app.services.adapters.storybook.primary import Storybook, StorybookPage
from app.services.adapters.text.models import StoryOutput
from app.services.generation_jobs import create_generation_job, record_generation_event
# Module-level pytest marker: applies ``pytest.mark.asyncio`` to the tests in this module.
pytestmark = pytest.mark.asyncio
def build_storybook_output() -> Storybook:
    """Return a fixed two-page ``Storybook`` for use as a mocked generator result."""
    first_page = StorybookPage(
        page_number=1,
        text="露露第一次走进会发光的森林。",
        image_prompt="Lulu entering a glowing forest",
    )
    second_page = StorybookPage(
        page_number=2,
        text="她遇到了一只会唱歌的萤火虫。",
        image_prompt="Lulu meeting a singing firefly",
    )
    storybook = Storybook(
        title="森林里的发光冒险",
        main_character="小兔子露露",
        art_style="温暖水彩",
        cover_prompt="A glowing forest storybook cover",
        pages=[first_page, second_page],
    )
    return storybook
async def test_unified_generation_records_job_events_and_retryable_assets(
    db_session,
    test_user,
    auth_token,
    mock_text_provider,
):
    """Check that a story generation request leaves a complete job trail.

    Posts a keyword-based story request with image generation disabled, then
    verifies the response payload, the persisted ``GenerationJob`` row, the
    ordered ``GenerationJobEvent`` rows, and the job detail and per-story job
    list endpoints.
    """
    # Event sequence expected both in the database and in the detail endpoint.
    expected_event_types = [
        "request_accepted",
        "context_prepared",
        "narrative_generated",
        "story_saved",
        "generation_completed",
    ]

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            request_body = {
                "output_mode": "story",
                "type": "keywords",
                "data": "小兔子, 森林",
                "generate_images": False,
            }
            response = await client.post("/api/generations", json=request_body)
            assert response.status_code == 200

            payload = response.json()
            assert payload["generation_status"] == "partial_ready"
            assert payload["retryable_assets"] == ["image", "audio"]
            assert payload["generation_job_id"]

            # Exactly one job row should exist for the requesting user.
            job_query = select(GenerationJob).where(
                GenerationJob.user_id == test_user.id
            )
            job_rows = (await db_session.execute(job_query)).scalars().all()
            assert len(job_rows) == 1
            job = job_rows[0]
            assert job.story_id == payload["id"]
            assert job.output_mode == "story"
            assert job.input_type == "keywords"
            assert job.status == "completed"
            assert job.current_step == "generation_completed"
            assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
            assert payload["generation_job_id"] == job.id

            # Events are read back in insertion (primary key) order.
            event_query = (
                select(GenerationJobEvent)
                .where(GenerationJobEvent.job_id == job.id)
                .order_by(GenerationJobEvent.id)
            )
            events = (await db_session.execute(event_query)).scalars().all()
            assert [event.event_type for event in events] == expected_event_types

            # The assertion above pins the list to five entries, so unpacking is safe.
            _, context_event, narrative_event, saved_event, _ = events
            assert context_event.event_metadata["has_memory_context"] is False
            assert narrative_event.event_metadata["title"] == "小兔子的冒险"
            assert saved_event.story_id == payload["id"]

            detail_response = await client.get(f"/api/generations/jobs/{job.id}")
            assert detail_response.status_code == 200
            detail = detail_response.json()
            assert detail["id"] == job.id
            assert detail["story_id"] == payload["id"]
            assert detail["progress_percent"] == 100
            assert detail["progress_label"] == "已完成"
            assert detail["is_terminal"] is True
            detail_event_types = [entry["event_type"] for entry in detail["events"]]
            assert detail_event_types == expected_event_types

            list_response = await client.get(f"/api/generations/{payload['id']}/jobs")
            assert list_response.status_code == 200
            job_list = list_response.json()
            assert [item["id"] for item in job_list] == [job.id]
            first_listed = job_list[0]
            assert first_listed["progress_percent"] == 100
            assert first_listed["is_terminal"] is True
    finally:
        app.dependency_overrides.clear()
async def test_asset_retry_records_job_events_and_updates_retryable_assets(
    db_session,
    test_user,
    auth_token,
    degraded_story_with_text,
    mock_image_provider,
):
    """Check that retrying the image asset records an ``asset_retry`` job trail.

    Posts an image retry for the degraded story, then verifies the response,
    the single persisted ``asset_retry`` job, and its ordered events.
    """
    story_id = degraded_story_with_text.id

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.post(
                f"/api/generations/{story_id}/retry-assets",
                json={"assets": ["image"]},
            )
            assert response.status_code == 200

            payload = response.json()
            assert payload["image_status"] == "ready"
            assert payload["retryable_assets"] == ["audio"]

            # Only the retry job for this story is of interest here.
            job_query = select(GenerationJob).where(
                GenerationJob.story_id == degraded_story_with_text.id,
                GenerationJob.output_mode == "asset_retry",
            )
            job_rows = (await db_session.execute(job_query)).scalars().all()
            assert len(job_rows) == 1
            job = job_rows[0]
            assert job.status == "completed"
            assert job.current_step == "asset_retry_completed"
            assert job.result_snapshot["retryable_assets"] == ["audio"]

            event_query = (
                select(GenerationJobEvent)
                .where(GenerationJobEvent.job_id == job.id)
                .order_by(GenerationJobEvent.id)
            )
            events = (await db_session.execute(event_query)).scalars().all()
            recorded_types = [event.event_type for event in events]
            assert recorded_types == [
                "request_accepted",
                "asset_retry_started",
                "cover_image_started",
                "cover_image_succeeded",
                "asset_retry_completed",
            ]

            # Index 3 is the "cover_image_succeeded" event per the order above.
            cover_succeeded = events[3]
            assert cover_succeeded.event_metadata["asset"] == "cover_image"
    finally:
        app.dependency_overrides.clear()
async def test_storybook_generation_records_page_image_events(
    db_session,
    auth_token,
):
    """Check that storybook generation records cover and per-page image events.

    ``generate_storybook`` and ``generate_image`` in ``story_service`` are
    patched with ``AsyncMock`` so the request yields the two-page storybook
    from ``build_storybook_output`` and three canned image URLs.
    """
    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    transport = ASGITransport(app=app)
    try:
        storybook_patcher = patch(
            "app.services.story_service.generate_storybook",
            new_callable=AsyncMock,
        )
        image_patcher = patch(
            "app.services.story_service.generate_image",
            new_callable=AsyncMock,
        )
        with storybook_patcher as mock_storybook, image_patcher as mock_image:
            mock_storybook.return_value = build_storybook_output()
            # One URL for the cover followed by one per storybook page.
            mock_image.side_effect = [
                "https://example.com/storybook-cover.png",
                "https://example.com/storybook-page-1.png",
                "https://example.com/storybook-page-2.png",
            ]
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                request_body = {
                    "output_mode": "storybook",
                    "type": "keywords",
                    "data": "森林, 发光, 友情",
                    "page_count": 6,
                    "generate_images": True,
                }
                response = await client.post("/api/generations", json=request_body)
                assert response.status_code == 200

                payload = response.json()
                assert payload["mode"] == "storybook"
                assert payload["image_status"] == "ready"

                job_query = select(GenerationJob).where(
                    GenerationJob.story_id == payload["id"],
                    GenerationJob.output_mode == "storybook",
                )
                job = (await db_session.execute(job_query)).scalar_one()

                event_query = (
                    select(GenerationJobEvent)
                    .where(GenerationJobEvent.job_id == job.id)
                    .order_by(GenerationJobEvent.id)
                )
                events = (await db_session.execute(event_query)).scalars().all()
                recorded_types = [event.event_type for event in events]
                assert recorded_types == [
                    "request_accepted",
                    "context_prepared",
                    "narrative_generated",
                    "storybook_images_started",
                    "storybook_cover_image_succeeded",
                    "storybook_page_image_succeeded",
                    "storybook_page_image_succeeded",
                    "storybook_images_completed",
                    "story_saved",
                    "generation_completed",
                ]

                # Page-level events must carry the page numbers in order.
                page_numbers = [
                    event.event_metadata["page_number"]
                    for event in events
                    if event.event_type == "storybook_page_image_succeeded"
                ]
                assert page_numbers == [1, 2]

                # Index 7 is "storybook_images_completed" per the order above.
                images_completed = events[7]
                assert images_completed.event_metadata["completed_pages"] == [1, 2]
    finally:
        app.dependency_overrides.clear()
async def test_provider_call_events_record_latency_and_cost(
    db_session,
    test_user,
):
    """Check that a routed text generation records provider call events.

    The provider list and the adapter registry lookup are patched so that
    ``generate_story_content`` runs a stub adapter; the test then inspects the
    events stored for the job.
    """
    from app.services import provider_router

    canned_story = StoryOutput(
        mode="generated",
        title="带供应商轨迹的故事",
        story_text="一只小鹿学会了复盘。",
        cover_prompt_suggestion="A deer with a golden bookmark",
    )

    # Stub adapter class handed out by the patched registry lookup.
    class MockAdapter:
        estimated_cost = 0.0123

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return canned_story

    job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "小鹿"},
    )

    providers_patcher = patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
    )
    with providers_patcher as mock_providers:
        mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)]
        registry_patcher = patch.object(
            provider_router.AdapterRegistry,
            "get",
            return_value=MockAdapter,
        )
        with registry_patcher:
            result = await provider_router.generate_story_content(
                input_type="keywords",
                data="小鹿",
                db=db_session,
                generation_job=job,
            )

    assert result == canned_story

    event_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    events = (await db_session.execute(event_query)).scalars().all()
    assert [event.event_type for event in events] == [
        "request_accepted",
        "provider_call_started",
        "provider_call_succeeded",
    ]

    # The assertion above pins the list to three entries; the last is the success event.
    _, _, provider_event = events
    metadata = provider_event.event_metadata
    assert metadata["capability"] == "text"
    assert metadata["adapter"] == "demo"
    assert metadata["strategy"] == "priority"
    assert metadata["latency_ms"] >= 0
    assert metadata["estimated_cost_usd"] == 0.0123
async def test_story_provider_stats_aggregate_job_events(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    """Check that the per-story provider-stats endpoint aggregates provider events.

    Seeds one ``asset_retry`` job with a succeeded and a failed provider call,
    then verifies the totals and the per-provider breakdown returned by
    ``GET /api/generations/{story_id}/provider-stats``.
    """
    async def override_get_db():
        yield db_session

    # Seed rows directly through the session; this does not need the app override.
    job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "estimated_cost_usd": 0.02,
            "error": "timeout",
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded and immediately before
    # the ``try`` whose ``finally`` removes it, so a seeding failure cannot
    # leave the override registered on the shared ``app`` object.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats"
            )
            assert response.status_code == 200

            data = response.json()
            assert data["story_id"] == degraded_story_with_text.id
            assert data["total_calls"] == 2
            assert data["successful_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 81.0
            # Expected total matches the succeeded call's cost only; the failed
            # call was seeded with 0.02 but is expected to contribute 0.0.
            assert data["estimated_cost_usd"] == 0.01
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
            ]
    finally:
        app.dependency_overrides.clear()
async def test_user_provider_analytics_aggregate_across_stories(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """Check that the user-level provider analytics endpoint aggregates across stories.

    Seeds an image retry job (one succeeded and one failed provider call) on
    one story and an audio retry job (one succeeded call) on another, then
    verifies the totals and per-provider breakdown returned by
    ``GET /api/generations/provider-analytics``.
    """
    async def override_get_db():
        yield db_session

    # Seed rows directly through the session; this does not need the app override.
    image_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "error": "timeout",
        },
    )
    audio_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_retry",
        input_type="audio",
        request_payload={"assets": ["audio"]},
        story_id=test_story.id,
    )
    await record_generation_event(
        db_session,
        job=audio_job,
        story_id=test_story.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "tts",
            "adapter": "edge_tts",
            "strategy": "priority",
            "latency_ms": 18,
            "estimated_cost_usd": 0.003,
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded and immediately before
    # the ``try`` whose ``finally`` removes it, so a seeding failure cannot
    # leave the override registered on the shared ``app`` object.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)
            response = await client.get("/api/generations/provider-analytics")
            assert response.status_code == 200

            data = response.json()
            assert data["job_count"] == 2
            assert data["story_count"] == 2
            assert data["total_calls"] == 3
            assert data["successful_calls"] == 2
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 60.0
            # The total is a sum of floats (0.01 + 0.003); compare approximately
            # so the check does not depend on the exact binary rounding of the sum.
            assert data["estimated_cost_usd"] == pytest.approx(0.013)
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
                {
                    "capability": "tts",
                    "adapter": "edge_tts",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 18.0,
                    "estimated_cost_usd": 0.003,
                },
            ]
    finally:
        app.dependency_overrides.clear()