1036 lines
35 KiB
Python
1036 lines
35 KiB
Python
"""Generation job tracking tests."""
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import select
|
|
|
|
from app.db.database import get_db
|
|
from app.db.models import GenerationJob, GenerationJobEvent
|
|
from app.main import app
|
|
from app.services.adapters import AdapterConfig
|
|
from app.services.adapters.storybook.primary import Storybook, StorybookPage
|
|
from app.services.adapters.text.models import StoryOutput
|
|
from app.services.generation_jobs import (
|
|
create_generation_job,
|
|
finish_generation_job,
|
|
get_generation_job_detail,
|
|
mark_stale_generation_jobs,
|
|
record_generation_event,
|
|
)
|
|
from app.services.story_service import run_generation_job_service
|
|
|
|
# Apply the asyncio marker to every test in this module (all tests here are `async def`).
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
def build_storybook_output() -> Storybook:
    """Return a canned two-page storybook used as the mocked generator result.

    Pages are numbered from 1 in the order listed below; each carries its page
    text and the prompt later handed to the (mocked) image generator.
    """
    page_specs = [
        ("露露第一次走进会发光的森林。", "Lulu entering a glowing forest"),
        ("她遇到了一只会唱歌的萤火虫。", "Lulu meeting a singing firefly"),
    ]
    pages = [
        StorybookPage(page_number=number, text=page_text, image_prompt=prompt)
        for number, (page_text, prompt) in enumerate(page_specs, start=1)
    ]
    return Storybook(
        title="森林里的发光冒险",
        main_character="小兔子露露",
        art_style="温暖水彩",
        cover_prompt="A glowing forest storybook cover",
        pages=pages,
    )
|
|
|
|
|
|
async def test_unified_generation_is_queued_then_worker_persists_story_and_events(
    db_session,
    test_user,
    auth_token,
    mock_text_provider,
):
    """Queue a story request, run the worker inline, then verify DB rows and API views.

    The Celery ``delay`` hook is patched so the POST only enqueues; the worker is
    then driven directly through ``run_generation_job_service``.
    """

    async def _session_override():
        yield db_session

    expected_trail = [
        "request_accepted",
        "worker_started",
        "context_prepared",
        "narrative_generated",
        "story_saved",
        "generation_completed",
    ]
    delay_target = "app.tasks.generation_workflow.run_generation_workflow_task.delay"

    app.dependency_overrides[get_db] = _session_override
    asgi = ASGITransport(app=app)

    try:
        with patch(delay_target) as enqueue_mock:
            async with AsyncClient(transport=asgi, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                submit = await client.post(
                    "/api/generations",
                    json={
                        "output_mode": "story",
                        "type": "keywords",
                        "data": "小兔子, 森林",
                        "generate_images": False,
                    },
                )

                # 202: accepted and handed to the (mocked) queue; no story yet.
                assert submit.status_code == 202
                accepted = submit.json()
                assert accepted["id"] is None
                assert accepted["generation_status"] == "queued"
                assert accepted["text_status"] == "generating"
                assert accepted["retryable_assets"] == []
                assert accepted["generation_job_id"]
                enqueue_mock.assert_called_once_with(accepted["generation_job_id"])

                user_jobs_query = select(GenerationJob).where(
                    GenerationJob.user_id == test_user.id
                )
                user_jobs = (await db_session.execute(user_jobs_query)).scalars().all()
                assert len(user_jobs) == 1
                queued_job = user_jobs[0]
                assert queued_job.story_id is None
                assert queued_job.output_mode == "story"
                assert queued_job.input_type == "keywords"
                assert queued_job.status == "running"
                assert queued_job.current_step == "request_accepted"
                assert accepted["generation_job_id"] == queued_job.id

                # Drive the worker synchronously instead of through the task queue.
                await run_generation_job_service(queued_job.id, db_session)

                job_query = select(GenerationJob).where(GenerationJob.id == queued_job.id)
                finished_job = (await db_session.execute(job_query)).scalar_one()
                assert finished_job.story_id is not None
                assert finished_job.status == "completed"
                assert finished_job.current_step == "generation_completed"
                assert finished_job.result_snapshot["retryable_assets"] == ["image", "audio"]

                trail_query = (
                    select(GenerationJobEvent)
                    .where(GenerationJobEvent.job_id == finished_job.id)
                    .order_by(GenerationJobEvent.id)
                )
                trail = (await db_session.execute(trail_query)).scalars().all()
                assert [entry.event_type for entry in trail] == expected_trail
                assert trail[2].event_metadata["has_memory_context"] is False
                assert trail[3].event_metadata["title"] == "小兔子的冒险"
                assert trail[4].story_id == finished_job.story_id

                # The job detail endpoint mirrors the persisted trail.
                detail_reply = await client.get(f"/api/generations/jobs/{finished_job.id}")
                assert detail_reply.status_code == 200
                detail = detail_reply.json()
                assert detail["id"] == finished_job.id
                assert detail["story_id"] == finished_job.story_id
                assert detail["progress_percent"] == 100
                assert detail["progress_label"] == "已完成"
                assert detail["is_terminal"] is True
                assert [entry["event_type"] for entry in detail["events"]] == expected_trail

                story_reply = await client.get(f"/api/generations/{finished_job.story_id}")
                assert story_reply.status_code == 200
                story_payload = story_reply.json()
                assert story_payload["generation_status"] == "partial_ready"
                assert story_payload["retryable_assets"] == ["image", "audio"]

                jobs_reply = await client.get(f"/api/generations/{finished_job.story_id}/jobs")
                assert jobs_reply.status_code == 200
                listed = jobs_reply.json()
                assert [item["id"] for item in listed] == [finished_job.id]
                assert listed[0]["progress_percent"] == 100
                assert listed[0]["is_terminal"] is True
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_asset_retry_records_job_events_and_updates_retryable_assets(
    db_session,
    test_user,
    auth_token,
    degraded_story_with_text,
    mock_image_provider,
):
    """Retrying the image asset creates an ``asset_retry`` job with a cover-image event trail."""

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    asgi = ASGITransport(app=app)

    try:
        async with AsyncClient(transport=asgi, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            retry_reply = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )

            assert retry_reply.status_code == 200
            story_view = retry_reply.json()
            assert story_view["image_status"] == "ready"
            # The image is repaired, so only audio is still retryable.
            assert story_view["retryable_assets"] == ["audio"]

            retry_jobs_query = select(GenerationJob).where(
                GenerationJob.story_id == degraded_story_with_text.id,
                GenerationJob.output_mode == "asset_retry",
            )
            retry_jobs = (await db_session.execute(retry_jobs_query)).scalars().all()
            assert len(retry_jobs) == 1
            retry_job = retry_jobs[0]
            assert retry_job.status == "completed"
            assert retry_job.current_step == "asset_retry_completed"
            assert retry_job.result_snapshot["retryable_assets"] == ["audio"]

            trail_query = (
                select(GenerationJobEvent)
                .where(GenerationJobEvent.job_id == retry_job.id)
                .order_by(GenerationJobEvent.id)
            )
            trail = (await db_session.execute(trail_query)).scalars().all()
            assert [entry.event_type for entry in trail] == [
                "request_accepted",
                "asset_retry_started",
                "cover_image_started",
                "cover_image_succeeded",
                "asset_retry_completed",
            ]
            assert trail[3].event_metadata["asset"] == "cover_image"
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_storybook_generation_is_queued_then_worker_records_page_image_events(
    db_session,
    auth_token,
):
    """Queue a storybook request, run the worker, and check the per-page image events.

    The storybook generator and the image generator are replaced with AsyncMocks.
    The generator returns the canned two-page book. The image mock yields a cover
    URL followed by one URL per page.
    """

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    asgi = ASGITransport(app=app)

    try:
        with patch(
            "app.services.story_service.generate_storybook",
            new_callable=AsyncMock,
        ) as storybook_mock, patch(
            "app.services.story_service.generate_image",
            new_callable=AsyncMock,
        ) as image_mock:
            storybook_mock.return_value = build_storybook_output()
            image_mock.side_effect = [
                "https://example.com/storybook-cover.png",
                "https://example.com/storybook-page-1.png",
                "https://example.com/storybook-page-2.png",
            ]

            delay_target = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
            with patch(delay_target) as enqueue_mock:
                async with AsyncClient(transport=asgi, base_url="http://test") as client:
                    client.cookies.set("access_token", auth_token)

                    submit = await client.post(
                        "/api/generations",
                        json={
                            "output_mode": "storybook",
                            "type": "keywords",
                            "data": "森林, 发光, 友情",
                            "page_count": 6,
                            "generate_images": True,
                        },
                    )

                    assert submit.status_code == 202
                    accepted = submit.json()
                    assert accepted["id"] is None
                    assert accepted["mode"] == "storybook"
                    assert accepted["generation_status"] == "queued"
                    assert accepted["text_status"] == "generating"
                    enqueue_mock.assert_called_once_with(accepted["generation_job_id"])

                    queued_query = select(GenerationJob).where(
                        GenerationJob.id == accepted["generation_job_id"],
                    )
                    queued_job = (await db_session.execute(queued_query)).scalar_one()
                    await run_generation_job_service(queued_job.id, db_session)

                    finished_query = select(GenerationJob).where(
                        GenerationJob.id == queued_job.id,
                    )
                    finished_job = (await db_session.execute(finished_query)).scalar_one()
                    assert finished_job.story_id is not None
                    assert finished_job.status == "completed"

                    trail_query = (
                        select(GenerationJobEvent)
                        .where(GenerationJobEvent.job_id == finished_job.id)
                        .order_by(GenerationJobEvent.id)
                    )
                    trail = (await db_session.execute(trail_query)).scalars().all()

                    assert [entry.event_type for entry in trail] == [
                        "request_accepted",
                        "worker_started",
                        "context_prepared",
                        "narrative_generated",
                        "storybook_images_started",
                        "storybook_cover_image_succeeded",
                        "storybook_page_image_succeeded",
                        "storybook_page_image_succeeded",
                        "storybook_images_completed",
                        "story_saved",
                        "generation_completed",
                    ]
                    page_numbers = [
                        entry.event_metadata["page_number"]
                        for entry in trail
                        if entry.event_type == "storybook_page_image_succeeded"
                    ]
                    assert page_numbers == [1, 2]
                    # Index 8 is the "storybook_images_completed" summary event.
                    assert trail[8].event_metadata["completed_pages"] == [1, 2]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_provider_call_events_record_latency_and_cost(
    db_session,
    test_user,
):
    """The provider router records started/succeeded events with latency and cost metadata.

    Provider discovery is patched to return a single "demo" provider, and the
    adapter registry is patched to hand back a stub adapter with a fixed cost.
    """
    from app.services import provider_router

    canned_story = StoryOutput(
        mode="generated",
        title="带供应商轨迹的故事",
        story_text="一只小鹿学会了复盘。",
        cover_prompt_suggestion="A deer with a golden bookmark",
    )

    class StubTextAdapter:
        """Adapter double: fixed cost estimate, always returns ``canned_story``."""

        estimated_cost = 0.0123

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return canned_story

    job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "小鹿"},
    )

    with patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
    ) as providers_mock, patch.object(
        provider_router.AdapterRegistry, "get", return_value=StubTextAdapter
    ):
        providers_mock.return_value = [("demo", AdapterConfig(api_key=""), None)]
        result = await provider_router.generate_story_content(
            input_type="keywords",
            data="小鹿",
            db=db_session,
            generation_job=job,
        )

    assert result == canned_story

    trail_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    trail = (await db_session.execute(trail_query)).scalars().all()

    assert [entry.event_type for entry in trail] == [
        "request_accepted",
        "provider_call_started",
        "provider_call_succeeded",
    ]
    succeeded_meta = trail[2].event_metadata
    assert succeeded_meta["capability"] == "text"
    assert succeeded_meta["adapter"] == "demo"
    assert succeeded_meta["strategy"] == "priority"
    assert succeeded_meta["latency_ms"] >= 0
    assert succeeded_meta["estimated_cost_usd"] == 0.0123
|
|
|
|
|
|
async def test_story_provider_stats_aggregate_job_events(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    """Per-story provider stats aggregate the provider_call_* events of that story.

    Seeds one succeeded and one failed image call on the same asset-retry job,
    then checks the totals and the per-adapter breakdown. The failed call's
    recorded cost (0.02) is expected to be excluded from every cost total.
    """

    async def override_get_db():
        yield db_session

    job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "estimated_cost_usd": 0.02,
            "error": "timeout",
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it onto
    # the shared `app` used by later tests.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats"
            )

            assert response.status_code == 200
            data = response.json()
            assert data["story_id"] == degraded_story_with_text.id
            assert data["total_calls"] == 2
            assert data["successful_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 81.0
            assert data["estimated_cost_usd"] == 0.01
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
            ]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_user_provider_analytics_aggregate_across_stories(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """User-level provider analytics aggregate provider calls across all of the user's stories.

    Seeds two image calls (one success, one timeout failure) on one story and a
    TTS success on another, then checks job/story counts, call totals, failure
    reasons and the per-adapter breakdown.
    """

    async def override_get_db():
        yield db_session

    image_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "image",
            "adapter": "demo",
            "strategy": "priority",
            "latency_ms": 42,
            "estimated_cost_usd": 0.01,
        },
    )
    await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "error": "timeout",
        },
    )

    audio_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_retry",
        input_type="audio",
        request_payload={"assets": ["audio"]},
        story_id=test_story.id,
    )
    await record_generation_event(
        db_session,
        job=audio_job,
        story_id=test_story.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "tts",
            "adapter": "edge_tts",
            "strategy": "priority",
            "latency_ms": 18,
            "estimated_cost_usd": 0.003,
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.get("/api/generations/provider-analytics")

            assert response.status_code == 200
            data = response.json()
            assert data["job_count"] == 2
            assert data["story_count"] == 2
            assert data["total_calls"] == 3
            assert data["successful_calls"] == 2
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 60.0
            # The total is a float sum (0.01 + 0.003), which can differ from the
            # literal 0.013 in the last bit, so compare approximately.
            assert data["estimated_cost_usd"] == pytest.approx(0.013)
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
                {
                    "capability": "tts",
                    "adapter": "edge_tts",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 18.0,
                    "estimated_cost_usd": 0.003,
                },
            ]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_provider_analytics_support_days_and_capability_filters(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """Provider analytics honour the ``days`` window and the ``capability`` filter.

    Seeds an image failure back-dated by 10 days and a fresh TTS success. A 7-day
    window should see only the TTS call, and ``capability=image`` should see only
    the image failure, both user-wide and per story.
    """

    async def override_get_db():
        yield db_session

    image_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    old_event = await record_generation_event(
        db_session,
        job=image_job,
        story_id=degraded_story_with_text.id,
        event_type="provider_call_failed",
        status="failed",
        metadata={
            "capability": "image",
            "adapter": "cqtai",
            "strategy": "priority",
            "latency_ms": 120,
            "error": "timeout",
        },
    )
    # Back-date the failure so it falls outside a 7-day window.
    old_event.created_at = datetime.now(timezone.utc) - timedelta(days=10)
    await db_session.commit()

    tts_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_retry",
        input_type="audio",
        request_payload={"assets": ["audio"]},
        story_id=test_story.id,
    )
    await record_generation_event(
        db_session,
        job=tts_job,
        story_id=test_story.id,
        event_type="provider_call_succeeded",
        status="succeeded",
        metadata={
            "capability": "tts",
            "adapter": "edge_tts",
            "strategy": "priority",
            "latency_ms": 18,
            "estimated_cost_usd": 0.003,
        },
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.get("/api/generations/provider-analytics?days=7")
            assert response.status_code == 200
            data = response.json()
            assert data["window_days"] == 7
            assert data["total_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == []

            response = await client.get(
                "/api/generations/provider-analytics?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["total_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]

            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_generation_ops_summary_exposes_running_stale_and_recent_failures(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    """The ops summary counts active, stale, failed and degraded jobs and lists recent failures.

    Seeds four jobs:
    - one running job, updated 10 minutes ago;
    - one running job, untouched for 3 hours (stale);
    - one failed asset retry, updated 1 hour ago;
    - one degraded storybook job, updated 30 minutes ago.
    """

    async def override_get_db():
        yield db_session

    running_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "星星"},
        story_id=test_story.id,
    )
    stale_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_generation",
        input_type="image",
        request_payload={"story_id": degraded_story_with_text.id},
        story_id=degraded_story_with_text.id,
    )
    failed_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_retry",
        input_type="image",
        request_payload={"assets": ["image"]},
        story_id=degraded_story_with_text.id,
    )
    degraded_job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="storybook",
        input_type="keywords",
        request_payload={"data": "月亮"},
        story_id=test_story.id,
    )

    # A single reference instant keeps the relative ages of the jobs exact.
    now = datetime.now(timezone.utc)
    stale_job.updated_at = now - timedelta(hours=3)
    failed_job.status = "failed"
    failed_job.current_step = "asset_retry_failed"
    failed_job.error_message = "image timeout"
    failed_job.updated_at = now - timedelta(hours=1)
    degraded_job.status = "degraded_completed"
    degraded_job.current_step = "generation_completed"
    degraded_job.updated_at = now - timedelta(minutes=30)
    running_job.updated_at = now - timedelta(minutes=10)
    await db_session.commit()

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.get("/api/generations/ops-summary?hours=48")

            assert response.status_code == 200
            data = response.json()
            assert data["window_hours"] == 48
            assert data["active_jobs"] == 2
            assert data["stale_running_jobs"] == 1
            assert data["failed_jobs"] == 1
            assert data["degraded_jobs"] == 1
            assert data["asset_retry_jobs"] == 2
            assert len(data["recent_failures"]) == 1
            assert data["recent_failures"][0]["job_id"] == failed_job.id
            assert data["recent_failures"][0]["story_title"] == degraded_story_with_text.title
            assert data["recent_failures"][0]["failure_label"] == "资源重试失败"
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_mark_stale_generation_jobs_marks_old_running_jobs_failed(
    db_session,
    degraded_story_with_text,
):
    """A running job idle for longer than the threshold is failed and given a final stale event."""
    threshold_minutes = 30

    idle_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "超时任务"},
        story_id=degraded_story_with_text.id,
    )
    # Two hours without progress comfortably exceeds the 30-minute threshold.
    idle_job.updated_at = datetime.now(timezone.utc) - timedelta(hours=2)
    await db_session.commit()

    summary = await mark_stale_generation_jobs(
        db_session, stale_after_minutes=threshold_minutes
    )
    assert summary == {"running": 1, "marked_stale": 1, "stale_after_minutes": 30}

    job_query = select(GenerationJob).where(GenerationJob.id == idle_job.id)
    reloaded = (await db_session.execute(job_query)).scalar_one()
    assert reloaded.status == "failed"
    assert reloaded.current_step == "generation_stale_failed"
    assert reloaded.error_message == "Generation job exceeded 30 minutes without progress."

    trail_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == idle_job.id)
        .order_by(GenerationJobEvent.id)
    )
    trail = (await db_session.execute(trail_query)).scalars().all()
    latest = trail[-1]
    assert latest.event_type == "generation_stale_failed"
    assert latest.event_metadata["stale_after_minutes"] == 30
|
|
|
|
|
|
async def test_retry_assets_rejects_when_story_has_active_job(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    """Retrying assets returns 409 while the story already has a running generation job."""

    async def override_get_db():
        yield db_session

    # A freshly created job is treated as active for this story.
    await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="asset_generation",
        input_type="image",
        request_payload={"story_id": degraded_story_with_text.id},
        story_id=degraded_story_with_text.id,
    )

    transport = ASGITransport(app=app)
    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )

            assert response.status_code == 409
            assert "已有运行中的任务" in response.json()["detail"]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_cancel_queued_generation_job_marks_it_canceled(
    db_session,
    auth_token,
    test_user,
):
    """Cancelling a job that is still queued finalizes it as canceled and allows a retry."""

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_db] = _session_override
    asgi = ASGITransport(app=app)
    delay_target = "app.tasks.generation_workflow.run_generation_workflow_task.delay"

    try:
        with patch(delay_target) as enqueue_mock:
            async with AsyncClient(transport=asgi, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                submit = await client.post(
                    "/api/generations",
                    json={
                        "output_mode": "story",
                        "type": "keywords",
                        "data": "小狐狸, 月亮船",
                        "generate_images": False,
                    },
                )

                assert submit.status_code == 202
                queued_id = submit.json()["generation_job_id"]
                enqueue_mock.assert_called_once_with(queued_id)

                cancel_reply = await client.post(f"/api/generations/jobs/{queued_id}/cancel")

                assert cancel_reply.status_code == 200
                snapshot = cancel_reply.json()
                assert snapshot["status"] == "canceled"
                assert snapshot["current_step"] == "generation_canceled"
                assert snapshot["can_cancel"] is False
                assert snapshot["can_retry"] is True

                detail = await get_generation_job_detail(
                    db_session,
                    job_id=queued_id,
                    user_id=test_user.id,
                )
                assert [entry["event_type"] for entry in detail["events"]] == [
                    "request_accepted",
                    "generation_canceled",
                ]
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_cancel_running_generation_job_marks_cancel_requested(
    db_session,
    auth_token,
    test_user,
):
    """Cancelling a job whose worker has started only flags it as ``cancel_requested``.

    The job keeps status ``running``. Neither cancel nor retry is offered while
    the request is pending.
    """

    async def override_get_db():
        yield db_session

    transport = ASGITransport(app=app)

    # Seed a job that a worker has already picked up (worker_started recorded).
    job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={
            "output_mode": "story",
            "type": "keywords",
            "data": "小熊, 森林",
            "generate_images": False,
        },
    )
    await record_generation_event(
        db_session,
        job=job,
        event_type="worker_started",
        status="running",
        message="Generation worker started processing the accepted request.",
    )

    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            response = await client.post(f"/api/generations/jobs/{job.id}/cancel")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "running"
            assert data["current_step"] == "cancel_requested"
            assert data["can_cancel"] is False
            assert data["can_retry"] is False

            refreshed_job = (
                await db_session.execute(select(GenerationJob).where(GenerationJob.id == job.id))
            ).scalar_one()
            assert refreshed_job.current_step == "cancel_requested"
            assert refreshed_job.error_message == "Cancellation requested by user."
    finally:
        app.dependency_overrides.clear()
|
|
|
|
|
|
async def test_retry_failed_generation_job_requeues_new_worker_job(
    db_session,
    auth_token,
    test_user,
    mock_text_provider,
):
    """Retrying a failed job creates and enqueues a new job that the worker then completes.

    The new job starts at ``retry_queued`` and can be cancelled. Its event trail
    begins with request_accepted, retry_queued and worker_started.
    """

    async def override_get_db():
        yield db_session

    transport = ASGITransport(app=app)
    task_delay_path = "app.tasks.generation_workflow.run_generation_workflow_task.delay"

    failed_job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={
            "output_mode": "story",
            "type": "keywords",
            "data": "小鹿, 星星",
            "generate_images": False,
        },
    )
    await finish_generation_job(
        db_session,
        job=failed_job,
        story=None,
        status="failed",
        current_step="generation_failed",
        error_message="upstream timeout",
        message="Generation failed before a durable story result was available.",
    )

    # Install the override only after seeding succeeded, directly before the
    # try/finally that clears it, so a failing setup step cannot leak it.
    app.dependency_overrides[get_db] = override_get_db
    try:
        with patch(task_delay_path) as mock_delay:
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")

                assert response.status_code == 200
                data = response.json()
                # The retry is a brand-new job, not a resurrection of the failed one.
                assert data["id"] != failed_job.id
                assert data["status"] == "running"
                assert data["current_step"] == "retry_queued"
                assert data["can_cancel"] is True
                assert data["can_retry"] is False
                mock_delay.assert_called_once_with(data["id"])

                retried_job_id = data["id"]
                await run_generation_job_service(retried_job_id, db_session)

                retried_job = (
                    await db_session.execute(
                        select(GenerationJob).where(GenerationJob.id == retried_job_id)
                    )
                ).scalar_one()
                assert retried_job.status == "completed"
                assert retried_job.current_step == "generation_completed"

                events = (
                    await db_session.execute(
                        select(GenerationJobEvent)
                        .where(GenerationJobEvent.job_id == retried_job_id)
                        .order_by(GenerationJobEvent.id)
                    )
                ).scalars().all()
                assert [event.event_type for event in events[:3]] == [
                    "request_accepted",
                    "retry_queued",
                    "worker_started",
                ]
    finally:
        app.dependency_overrides.clear()
|