"""Generation job tracking tests.

Covers the queued generation workflow: request acceptance and task dispatch
(``run_generation_workflow_task.delay`` is patched), worker execution via
``run_generation_job_service``, per-job event trails, cancel/retry endpoints,
provider-call statistics/analytics, the ops summary, and stale-job marking.
"""
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select

from app.db.database import get_db
from app.db.models import GenerationJob, GenerationJobEvent, Story
from app.main import app
from app.services.adapters import AdapterConfig
from app.services.adapters.storybook.primary import Storybook, StorybookPage
from app.services.adapters.text.models import StoryOutput
from app.services.generation_jobs import (
    create_generation_job,
    finish_generation_job,
    get_generation_job_detail,
    mark_stale_generation_jobs,
    record_generation_event,
)
from app.services.story_service import queue_story_asset_generation, run_generation_job_service

pytestmark = pytest.mark.asyncio

# Dotted path of the task dispatcher; patched so no real background job is queued.
TASK_DELAY_PATH = "app.tasks.generation_workflow.run_generation_workflow_task.delay"


@contextmanager
def _override_db(db_session):
    """Route the app's ``get_db`` dependency to *db_session* for the ``with`` body.

    The override is installed on entry and ``app.dependency_overrides`` is
    cleared on exit even when setup or assertions inside the body raise, so a
    failing test cannot leak its session into later tests.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        yield
    finally:
        app.dependency_overrides.clear()


@asynccontextmanager
async def _api_client(auth_token):
    """Yield an in-process ``AsyncClient`` that sends *auth_token* as the access-token cookie."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        client.cookies.set("access_token", auth_token)
        yield client


async def _fetch_job(db_session, job_id):
    """Load the single ``GenerationJob`` row with primary key *job_id*."""
    result = await db_session.execute(
        select(GenerationJob).where(GenerationJob.id == job_id)
    )
    return result.scalar_one()


async def _fetch_job_events(db_session, job_id):
    """Return all ``GenerationJobEvent`` rows of *job_id*, ordered by event id."""
    result = await db_session.execute(
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job_id)
        .order_by(GenerationJobEvent.id)
    )
    return result.scalars().all()


def build_storybook_output() -> Storybook:
    """Create a reusable mocked storybook payload (a cover prompt plus two pages)."""
    return Storybook(
        title="森林里的发光冒险",
        main_character="小兔子露露",
        art_style="温暖水彩",
        cover_prompt="A glowing forest storybook cover",
        pages=[
            StorybookPage(
                page_number=1,
                text="露露第一次走进会发光的森林。",
                image_prompt="Lulu entering a glowing forest",
            ),
            StorybookPage(
                page_number=2,
                text="她遇到了一只会唱歌的萤火虫。",
                image_prompt="Lulu meeting a singing firefly",
            ),
        ],
    )


async def test_unified_generation_is_queued_then_worker_persists_story_and_events(
    db_session,
    test_user,
    auth_token,
    mock_text_provider,
):
    expected_event_types = [
        "request_accepted",
        "worker_started",
        "context_prepared",
        "narrative_generated",
        "story_saved",
        "generation_completed",
    ]

    with _override_db(db_session), patch(TASK_DELAY_PATH) as mock_delay:
        async with _api_client(auth_token) as client:
            response = await client.post(
                "/api/generations",
                json={
                    "output_mode": "story",
                    "type": "keywords",
                    "data": "小兔子, 森林",
                    "generate_images": False,
                },
            )
            assert response.status_code == 202
            data = response.json()
            assert data["id"] is None
            assert data["generation_status"] == "queued"
            assert data["text_status"] == "generating"
            assert data["retryable_assets"] == []
            assert data["generation_job_id"]
            mock_delay.assert_called_once_with(data["generation_job_id"])

            jobs = (
                await db_session.execute(
                    select(GenerationJob).where(GenerationJob.user_id == test_user.id)
                )
            ).scalars().all()
            assert len(jobs) == 1
            job = jobs[0]
            assert job.story_id is None
            assert job.output_mode == "story"
            assert job.input_type == "keywords"
            assert job.status == "running"
            assert job.current_step == "request_accepted"
            assert data["generation_job_id"] == job.id

            # Run the worker inline instead of through the (patched) task queue.
            await run_generation_job_service(job.id, db_session)

            job = await _fetch_job(db_session, job.id)
            assert job.story_id is not None
            assert job.status == "completed"
            assert job.current_step == "generation_completed"
            assert job.result_snapshot["retryable_assets"] == ["image", "audio"]

            events = await _fetch_job_events(db_session, job.id)
            assert [event.event_type for event in events] == expected_event_types
            assert events[2].event_metadata["has_memory_context"] is False
            assert events[3].event_metadata["title"] == "小兔子的冒险"
            assert events[4].story_id == job.story_id

            detail_response = await client.get(f"/api/generations/jobs/{job.id}")
            assert detail_response.status_code == 200
            detail = detail_response.json()
            assert detail["id"] == job.id
            assert detail["story_id"] == job.story_id
            assert detail["progress_percent"] == 100
            assert detail["progress_label"] == "已完成"
            assert detail["is_terminal"] is True
            assert [event["event_type"] for event in detail["events"]] == expected_event_types

            story_response = await client.get(f"/api/generations/{job.story_id}")
            assert story_response.status_code == 200
            story_data = story_response.json()
            assert story_data["generation_status"] == "partial_ready"
            assert story_data["retryable_assets"] == ["image", "audio"]

            list_response = await client.get(f"/api/generations/{job.story_id}/jobs")
            assert list_response.status_code == 200
            job_list = list_response.json()
            assert [item["id"] for item in job_list] == [job.id]
            assert job_list[0]["progress_percent"] == 100
            assert job_list[0]["is_terminal"] is True


async def test_asset_retry_records_job_events_and_updates_retryable_assets(
    db_session,
    test_user,
    auth_token,
    degraded_story_with_text,
    mock_image_provider,
):
    with _override_db(db_session):
        async with _api_client(auth_token) as client:
            response = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )
            assert response.status_code == 200
            data = response.json()
            assert data["image_status"] == "ready"
            assert data["retryable_assets"] == ["audio"]

            jobs = (
                await db_session.execute(
                    select(GenerationJob).where(
                        GenerationJob.story_id == degraded_story_with_text.id,
                        GenerationJob.output_mode == "asset_retry",
                    )
                )
            ).scalars().all()
            assert len(jobs) == 1
            job = jobs[0]
            assert job.status == "completed"
            assert job.current_step == "asset_retry_completed"
            assert job.result_snapshot["retryable_assets"] == ["audio"]

            events = await _fetch_job_events(db_session, job.id)
            assert [event.event_type for event in events] == [
                "request_accepted",
                "asset_retry_started",
                "cover_image_started",
                "cover_image_succeeded",
                "asset_retry_completed",
            ]
            assert events[3].event_metadata["asset"] == "cover_image"


async def test_queue_story_asset_generation_dispatches_background_job(
    db_session,
    test_story,
):
    with patch(TASK_DELAY_PATH) as mock_delay:
        summary = await queue_story_asset_generation(
            test_story.id,
            test_story.user_id,
            ["image"],
            db_session,
        )

    assert summary["output_mode"] == "asset_generation"
    assert summary["input_type"] == "image"
    assert summary["status"] == "running"
    assert summary["current_step"] == "request_accepted"
    assert summary["can_cancel"] is True
    assert summary["can_retry"] is False
    mock_delay.assert_called_once_with(summary["id"])

    job = await _fetch_job(db_session, summary["id"])
    assert job.story_id == test_story.id
    assert job.output_mode == "asset_generation"


async def test_asset_generation_job_worker_completes_cover_image(
    db_session,
    test_story,
):
    job = await create_generation_job(
        db_session,
        user_id=test_story.user_id,
        output_mode="asset_generation",
        input_type="image",
        request_payload={"story_id": test_story.id, "assets": ["image"]},
        story_id=test_story.id,
    )

    with patch(
        "app.services.story_service.generate_image",
        new_callable=AsyncMock,
    ) as mock_generate_image:
        mock_generate_image.return_value = "https://example.com/async-cover.png"
        await run_generation_job_service(job.id, db_session)

    refreshed_job = await _fetch_job(db_session, job.id)
    assert refreshed_job.status == "completed"
    assert refreshed_job.current_step == "asset_generation_completed"
    assert refreshed_job.result_snapshot["image_status"] == "ready"

    story = (
        await db_session.execute(select(Story).where(Story.id == test_story.id))
    ).scalar_one()
    assert story.image_url == "https://example.com/async-cover.png"

    events = await _fetch_job_events(db_session, job.id)
    assert [event.event_type for event in events] == [
        "request_accepted",
        "worker_started",
        "cover_image_started",
        "cover_image_succeeded",
        "asset_generation_completed",
    ]


async def test_cancel_queued_asset_generation_job_marks_it_canceled(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    with _override_db(db_session):
        job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_generation",
            input_type="image",
            request_payload={"story_id": degraded_story_with_text.id, "assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )

        async with _api_client(auth_token) as client:
            response = await client.post(f"/api/generations/jobs/{job.id}/cancel")
            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "canceled"
            assert data["current_step"] == "generation_canceled"
            assert data["can_cancel"] is False
            assert data["can_retry"] is True


async def test_retry_failed_asset_generation_job_requeues_new_worker_job(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    with _override_db(db_session):
        failed_job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_generation",
            input_type="image",
            request_payload={"story_id": degraded_story_with_text.id, "assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )
        await finish_generation_job(
            db_session,
            job=failed_job,
            story=degraded_story_with_text,
            status="failed",
            current_step="asset_generation_failed",
            error_message="cover timeout",
            message="Cover image generation failed.",
        )

        with patch(TASK_DELAY_PATH) as mock_delay:
            async with _api_client(auth_token) as client:
                response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")
                assert response.status_code == 200
                data = response.json()
                # A retry creates a brand-new job rather than reusing the failed one.
                assert data["id"] != failed_job.id
                assert data["output_mode"] == "asset_generation"
                assert data["status"] == "running"
                assert data["current_step"] == "retry_queued"
                assert data["can_cancel"] is True
                assert data["can_retry"] is False
                mock_delay.assert_called_once_with(data["id"])


async def test_storybook_generation_is_queued_then_worker_records_page_image_events(
    db_session,
    auth_token,
):
    with _override_db(db_session), patch(
        "app.services.story_service.generate_storybook",
        new_callable=AsyncMock,
    ) as mock_storybook, patch(
        "app.services.story_service.generate_image",
        new_callable=AsyncMock,
    ) as mock_image, patch(TASK_DELAY_PATH) as mock_delay:
        mock_storybook.return_value = build_storybook_output()
        # One URL per image call: the cover first, then each page in order.
        mock_image.side_effect = [
            "https://example.com/storybook-cover.png",
            "https://example.com/storybook-page-1.png",
            "https://example.com/storybook-page-2.png",
        ]

        async with _api_client(auth_token) as client:
            response = await client.post(
                "/api/generations",
                json={
                    "output_mode": "storybook",
                    "type": "keywords",
                    "data": "森林, 发光, 友情",
                    "page_count": 6,
                    "generate_images": True,
                },
            )
            assert response.status_code == 202
            data = response.json()
            assert data["id"] is None
            assert data["mode"] == "storybook"
            assert data["generation_status"] == "queued"
            assert data["text_status"] == "generating"
            mock_delay.assert_called_once_with(data["generation_job_id"])

            job = await _fetch_job(db_session, data["generation_job_id"])
            await run_generation_job_service(job.id, db_session)

            job = await _fetch_job(db_session, job.id)
            assert job.story_id is not None
            assert job.status == "completed"

            events = await _fetch_job_events(db_session, job.id)
            assert [event.event_type for event in events] == [
                "request_accepted",
                "worker_started",
                "context_prepared",
                "narrative_generated",
                "storybook_images_started",
                "storybook_cover_image_succeeded",
                "storybook_page_image_succeeded",
                "storybook_page_image_succeeded",
                "storybook_images_completed",
                "story_saved",
                "generation_completed",
            ]
            page_events = [
                event
                for event in events
                if event.event_type == "storybook_page_image_succeeded"
            ]
            assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
            # events[8] is "storybook_images_completed".
            assert events[8].event_metadata["completed_pages"] == [1, 2]


async def test_provider_call_events_record_latency_and_cost(
    db_session,
    test_user,
):
    from app.services import provider_router

    mock_result = StoryOutput(
        mode="generated",
        title="带供应商轨迹的故事",
        story_text="一只小鹿学会了复盘。",
        cover_prompt_suggestion="A deer with a golden bookmark",
    )

    class MockAdapter:
        """Adapter stub that always succeeds with ``mock_result``."""

        estimated_cost = 0.0123

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return mock_result

    job = await create_generation_job(
        db_session,
        user_id=test_user.id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "小鹿"},
    )

    with patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
    ) as mock_providers:
        mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)]
        with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
            result = await provider_router.generate_story_content(
                input_type="keywords",
                data="小鹿",
                db=db_session,
                generation_job=job,
            )

    assert result == mock_result

    events = await _fetch_job_events(db_session, job.id)
    assert [event.event_type for event in events] == [
        "request_accepted",
        "provider_call_started",
        "provider_call_succeeded",
    ]
    provider_event = events[2]
    assert provider_event.event_metadata["capability"] == "text"
    assert provider_event.event_metadata["adapter"] == "demo"
    assert provider_event.event_metadata["strategy"] == "priority"
    assert provider_event.event_metadata["latency_ms"] >= 0
    assert provider_event.event_metadata["estimated_cost_usd"] == 0.0123


async def test_story_provider_stats_aggregate_job_events(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    with _override_db(db_session):
        job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_retry",
            input_type="image",
            request_payload={"assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )
        await record_generation_event(
            db_session,
            job=job,
            story_id=degraded_story_with_text.id,
            event_type="provider_call_succeeded",
            status="succeeded",
            metadata={
                "capability": "image",
                "adapter": "demo",
                "strategy": "priority",
                "latency_ms": 42,
                "estimated_cost_usd": 0.01,
            },
        )
        await record_generation_event(
            db_session,
            job=job,
            story_id=degraded_story_with_text.id,
            event_type="provider_call_failed",
            status="failed",
            metadata={
                "capability": "image",
                "adapter": "cqtai",
                "strategy": "priority",
                "latency_ms": 120,
                "estimated_cost_usd": 0.02,
                "error": "timeout",
            },
        )

        async with _api_client(auth_token) as client:
            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["story_id"] == degraded_story_with_text.id
            assert data["total_calls"] == 2
            assert data["successful_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 81.0
            # The failed call's 0.02 is not counted: only the successful call's cost appears.
            assert data["estimated_cost_usd"] == 0.01
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
            ]


async def test_user_provider_analytics_aggregate_across_stories(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    with _override_db(db_session):
        image_job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_retry",
            input_type="image",
            request_payload={"assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )
        await record_generation_event(
            db_session,
            job=image_job,
            story_id=degraded_story_with_text.id,
            event_type="provider_call_succeeded",
            status="succeeded",
            metadata={
                "capability": "image",
                "adapter": "demo",
                "strategy": "priority",
                "latency_ms": 42,
                "estimated_cost_usd": 0.01,
            },
        )
        await record_generation_event(
            db_session,
            job=image_job,
            story_id=degraded_story_with_text.id,
            event_type="provider_call_failed",
            status="failed",
            metadata={
                "capability": "image",
                "adapter": "cqtai",
                "strategy": "priority",
                "latency_ms": 120,
                "error": "timeout",
            },
        )

        audio_job = await create_generation_job(
            db_session,
            user_id=test_story.user_id,
            output_mode="asset_retry",
            input_type="audio",
            request_payload={"assets": ["audio"]},
            story_id=test_story.id,
        )
        await record_generation_event(
            db_session,
            job=audio_job,
            story_id=test_story.id,
            event_type="provider_call_succeeded",
            status="succeeded",
            metadata={
                "capability": "tts",
                "adapter": "edge_tts",
                "strategy": "priority",
                "latency_ms": 18,
                "estimated_cost_usd": 0.003,
            },
        )

        async with _api_client(auth_token) as client:
            response = await client.get("/api/generations/provider-analytics")
            assert response.status_code == 200
            data = response.json()
            assert data["job_count"] == 2
            assert data["story_count"] == 2
            assert data["total_calls"] == 3
            assert data["successful_calls"] == 2
            assert data["failed_calls"] == 1
            assert data["avg_latency_ms"] == 60.0
            assert data["estimated_cost_usd"] == 0.013
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]
            assert data["by_provider"] == [
                {
                    "capability": "image",
                    "adapter": "cqtai",
                    "call_count": 1,
                    "success_count": 0,
                    "failure_count": 1,
                    "avg_latency_ms": 120.0,
                    "estimated_cost_usd": 0.0,
                },
                {
                    "capability": "image",
                    "adapter": "demo",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 42.0,
                    "estimated_cost_usd": 0.01,
                },
                {
                    "capability": "tts",
                    "adapter": "edge_tts",
                    "call_count": 1,
                    "success_count": 1,
                    "failure_count": 0,
                    "avg_latency_ms": 18.0,
                    "estimated_cost_usd": 0.003,
                },
            ]


async def test_provider_analytics_support_days_and_capability_filters(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    with _override_db(db_session):
        image_job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_retry",
            input_type="image",
            request_payload={"assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )
        old_event = await record_generation_event(
            db_session,
            job=image_job,
            story_id=degraded_story_with_text.id,
            event_type="provider_call_failed",
            status="failed",
            metadata={
                "capability": "image",
                "adapter": "cqtai",
                "strategy": "priority",
                "latency_ms": 120,
                "error": "timeout",
            },
        )
        # Backdate the event so the 7-day window below excludes it.
        old_event.created_at = datetime.now(timezone.utc) - timedelta(days=10)
        await db_session.commit()

        tts_job = await create_generation_job(
            db_session,
            user_id=test_story.user_id,
            output_mode="asset_retry",
            input_type="audio",
            request_payload={"assets": ["audio"]},
            story_id=test_story.id,
        )
        await record_generation_event(
            db_session,
            job=tts_job,
            story_id=test_story.id,
            event_type="provider_call_succeeded",
            status="succeeded",
            metadata={
                "capability": "tts",
                "adapter": "edge_tts",
                "strategy": "priority",
                "latency_ms": 18,
                "estimated_cost_usd": 0.003,
            },
        )

        async with _api_client(auth_token) as client:
            response = await client.get("/api/generations/provider-analytics?days=7")
            assert response.status_code == 200
            data = response.json()
            assert data["window_days"] == 7
            assert data["total_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == []

            response = await client.get(
                "/api/generations/provider-analytics?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["total_calls"] == 1
            assert data["failed_calls"] == 1
            assert data["job_count"] == 1
            assert data["story_count"] == 1
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]

            response = await client.get(
                f"/api/generations/{degraded_story_with_text.id}/provider-stats?capability=image"
            )
            assert response.status_code == 200
            data = response.json()
            assert data["capability"] == "image"
            assert data["failure_reasons"] == [{"reason": "timeout", "count": 1}]


async def test_generation_ops_summary_exposes_running_stale_and_recent_failures(
    db_session,
    auth_token,
    degraded_story_with_text,
    test_story,
):
    with _override_db(db_session):
        running_job = await create_generation_job(
            db_session,
            user_id=test_story.user_id,
            output_mode="story",
            input_type="keywords",
            request_payload={"data": "星星"},
            story_id=test_story.id,
        )
        stale_job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_generation",
            input_type="image",
            request_payload={"story_id": degraded_story_with_text.id},
            story_id=degraded_story_with_text.id,
        )
        failed_job = await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_retry",
            input_type="image",
            request_payload={"assets": ["image"]},
            story_id=degraded_story_with_text.id,
        )
        degraded_job = await create_generation_job(
            db_session,
            user_id=test_story.user_id,
            output_mode="storybook",
            input_type="keywords",
            request_payload={"data": "月亮"},
            story_id=test_story.id,
        )

        now = datetime.now(timezone.utc)
        stale_job.updated_at = now - timedelta(hours=3)
        failed_job.status = "failed"
        failed_job.current_step = "asset_retry_failed"
        failed_job.error_message = "image timeout"
        failed_job.updated_at = now - timedelta(hours=1)
        degraded_job.status = "degraded_completed"
        degraded_job.current_step = "generation_completed"
        degraded_job.updated_at = now - timedelta(minutes=30)
        running_job.updated_at = now - timedelta(minutes=10)
        await db_session.commit()

        async with _api_client(auth_token) as client:
            response = await client.get("/api/generations/ops-summary?hours=48")
            assert response.status_code == 200
            data = response.json()
            assert data["window_hours"] == 48
            assert data["active_jobs"] == 2
            assert data["stale_running_jobs"] == 1
            assert data["failed_jobs"] == 1
            assert data["degraded_jobs"] == 1
            assert data["asset_retry_jobs"] == 2
            assert len(data["recent_failures"]) == 1
            assert data["recent_failures"][0]["job_id"] == failed_job.id
            assert data["recent_failures"][0]["story_title"] == degraded_story_with_text.title
            assert data["recent_failures"][0]["failure_label"] == "资源重试失败"


async def test_mark_stale_generation_jobs_marks_old_running_jobs_failed(
    db_session,
    degraded_story_with_text,
):
    stale_job = await create_generation_job(
        db_session,
        user_id=degraded_story_with_text.user_id,
        output_mode="story",
        input_type="keywords",
        request_payload={"data": "超时任务"},
        story_id=degraded_story_with_text.id,
    )
    stale_job.updated_at = datetime.now(timezone.utc) - timedelta(hours=2)
    await db_session.commit()

    result = await mark_stale_generation_jobs(db_session, stale_after_minutes=30)

    assert result == {"running": 1, "marked_stale": 1, "stale_after_minutes": 30}

    refreshed_job = await _fetch_job(db_session, stale_job.id)
    assert refreshed_job.status == "failed"
    assert refreshed_job.current_step == "generation_stale_failed"
    assert refreshed_job.error_message == "Generation job exceeded 30 minutes without progress."

    events = await _fetch_job_events(db_session, stale_job.id)
    assert events[-1].event_type == "generation_stale_failed"
    assert events[-1].event_metadata["stale_after_minutes"] == 30


async def test_retry_assets_rejects_when_story_has_active_job(
    db_session,
    auth_token,
    degraded_story_with_text,
):
    with _override_db(db_session):
        # A freshly created job is still running, which should block the retry.
        await create_generation_job(
            db_session,
            user_id=degraded_story_with_text.user_id,
            output_mode="asset_generation",
            input_type="image",
            request_payload={"story_id": degraded_story_with_text.id},
            story_id=degraded_story_with_text.id,
        )

        async with _api_client(auth_token) as client:
            response = await client.post(
                f"/api/generations/{degraded_story_with_text.id}/retry-assets",
                json={"assets": ["image"]},
            )
            assert response.status_code == 409
            assert "已有运行中的任务" in response.json()["detail"]


async def test_cancel_queued_generation_job_marks_it_canceled(
    db_session,
    auth_token,
    test_user,
):
    with _override_db(db_session), patch(TASK_DELAY_PATH) as mock_delay:
        async with _api_client(auth_token) as client:
            response = await client.post(
                "/api/generations",
                json={
                    "output_mode": "story",
                    "type": "keywords",
                    "data": "小狐狸, 月亮船",
                    "generate_images": False,
                },
            )
            assert response.status_code == 202
            job_id = response.json()["generation_job_id"]
            mock_delay.assert_called_once_with(job_id)

            cancel_response = await client.post(f"/api/generations/jobs/{job_id}/cancel")
            assert cancel_response.status_code == 200
            canceled_job = cancel_response.json()
            assert canceled_job["status"] == "canceled"
            assert canceled_job["current_step"] == "generation_canceled"
            assert canceled_job["can_cancel"] is False
            assert canceled_job["can_retry"] is True

            detail = await get_generation_job_detail(
                db_session,
                job_id=job_id,
                user_id=test_user.id,
            )
            assert [event["event_type"] for event in detail["events"]] == [
                "request_accepted",
                "generation_canceled",
            ]


async def test_cancel_running_generation_job_marks_cancel_requested(
    db_session,
    auth_token,
    test_user,
):
    with _override_db(db_session):
        job = await create_generation_job(
            db_session,
            user_id=test_user.id,
            output_mode="story",
            input_type="keywords",
            request_payload={
                "output_mode": "story",
                "type": "keywords",
                "data": "小熊, 森林",
                "generate_images": False,
            },
        )
        # Simulate a worker having picked the job up.
        await record_generation_event(
            db_session,
            job=job,
            event_type="worker_started",
            status="running",
            message="Generation worker started processing the accepted request.",
        )

        async with _api_client(auth_token) as client:
            response = await client.post(f"/api/generations/jobs/{job.id}/cancel")
            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "running"
            assert data["current_step"] == "cancel_requested"
            assert data["can_cancel"] is False
            assert data["can_retry"] is False

            refreshed_job = await _fetch_job(db_session, job.id)
            assert refreshed_job.current_step == "cancel_requested"
            assert refreshed_job.error_message == "Cancellation requested by user."


async def test_retry_failed_generation_job_requeues_new_worker_job(
    db_session,
    auth_token,
    test_user,
    mock_text_provider,
):
    with _override_db(db_session):
        failed_job = await create_generation_job(
            db_session,
            user_id=test_user.id,
            output_mode="story",
            input_type="keywords",
            request_payload={
                "output_mode": "story",
                "type": "keywords",
                "data": "小鹿, 星星",
                "generate_images": False,
            },
        )
        await finish_generation_job(
            db_session,
            job=failed_job,
            story=None,
            status="failed",
            current_step="generation_failed",
            error_message="upstream timeout",
            message="Generation failed before a durable story result was available.",
        )

        with patch(TASK_DELAY_PATH) as mock_delay:
            async with _api_client(auth_token) as client:
                response = await client.post(f"/api/generations/jobs/{failed_job.id}/retry")
                assert response.status_code == 200
                data = response.json()
                assert data["id"] != failed_job.id
                assert data["status"] == "running"
                assert data["current_step"] == "retry_queued"
                assert data["can_cancel"] is True
                assert data["can_retry"] is False
                mock_delay.assert_called_once_with(data["id"])

                retried_job_id = data["id"]
                await run_generation_job_service(retried_job_id, db_session)

                retried_job = await _fetch_job(db_session, retried_job_id)
                assert retried_job.status == "completed"
                assert retried_job.current_step == "generation_completed"

                events = await _fetch_job_events(db_session, retried_job_id)
                assert [event.event_type for event in events[:3]] == [
                    "request_accepted",
                    "retry_queued",
                    "worker_started",
                ]