feat: move unified generation to background worker
This commit is contained in:
@@ -67,13 +67,13 @@ def _mark_legacy_generation_endpoint(response: Response, successor: str) -> None
|
||||
response.headers.update(_legacy_generation_headers(successor))
|
||||
|
||||
|
||||
@router.post("/generations", response_model=GenerationResponse)
|
||||
@router.post("/generations", response_model=GenerationResponse, status_code=202)
|
||||
async def create_generation(
|
||||
request: GenerationRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a story or storybook through the unified generation workflow."""
|
||||
"""Accept one story/storybook generation request for background execution."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_generation_service(request, user.id, db)
|
||||
|
||||
|
||||
@@ -117,9 +117,9 @@ class StorybookResponse(StoryStatusMixin):
|
||||
class GenerationResponse(StoryStatusMixin):
|
||||
"""Unified generation response for the target workflow API."""
|
||||
|
||||
id: int
|
||||
id: int | None = None
|
||||
generation_job_id: str | None = None
|
||||
title: str
|
||||
title: str | None = None
|
||||
mode: str
|
||||
story_text: str | None = None
|
||||
pages: list[StorybookPageResponse] | None = None
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy import desc, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -59,6 +59,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
|
||||
progress_map: dict[str, tuple[int, str]] = {
|
||||
"request_accepted": (5, "已接收请求"),
|
||||
"worker_started": (12, "后台任务已开始"),
|
||||
"context_prepared": (20, "上下文已准备"),
|
||||
"narrative_generated": (45, "正文已生成"),
|
||||
"story_saved": (60, "主记录已保存"),
|
||||
@@ -66,8 +67,18 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"provider_call_succeeded": (72, "Provider 调用成功"),
|
||||
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
|
||||
"cover_image_started": (75, "封面生成中"),
|
||||
"cover_image_succeeded": (88, "封面已生成"),
|
||||
"cover_image_failed": (88, "封面生成失败"),
|
||||
"storybook_images_started": (75, "绘本插图生成中"),
|
||||
"storybook_cover_image_succeeded": (82, "绘本封面已生成"),
|
||||
"storybook_cover_image_failed": (82, "绘本封面生成失败"),
|
||||
"storybook_page_image_succeeded": (86, "分页插图已生成"),
|
||||
"storybook_page_image_failed": (86, "分页插图生成失败"),
|
||||
"storybook_images_completed": (92, "绘本插图已完成"),
|
||||
"audio_started": (75, "音频生成中"),
|
||||
"audio_cache_hit": (88, "音频缓存已复用"),
|
||||
"audio_succeeded": (88, "音频已生成"),
|
||||
"audio_failed": (88, "音频生成失败"),
|
||||
"asset_retry_started": (25, "资源重试中"),
|
||||
"postprocessing_queued": (90, "后处理已排队"),
|
||||
"asset_generation_completed": (100, "资源已完成"),
|
||||
@@ -155,6 +166,10 @@ async def record_generation_event(
|
||||
) -> GenerationJobEvent:
|
||||
"""Append one event to an existing generation job."""
|
||||
|
||||
job.current_step = event_type
|
||||
if story_id is not None:
|
||||
job.story_id = story_id
|
||||
|
||||
event = GenerationJobEvent(
|
||||
job_id=job.id,
|
||||
story_id=story_id if story_id is not None else job.story_id,
|
||||
@@ -169,6 +184,42 @@ async def record_generation_event(
|
||||
return event
|
||||
|
||||
|
||||
async def claim_generation_job_for_worker(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job_id: str,
|
||||
) -> GenerationJob | None:
|
||||
"""Claim one queued generation job for worker execution once."""
|
||||
|
||||
claim_result = await db.execute(
|
||||
update(GenerationJob)
|
||||
.where(
|
||||
GenerationJob.id == job_id,
|
||||
GenerationJob.status == "running",
|
||||
GenerationJob.current_step == "request_accepted",
|
||||
)
|
||||
.values(current_step="worker_started")
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
if not claim_result.rowcount:
|
||||
return None
|
||||
|
||||
result = await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
|
||||
job = result.scalar_one_or_none()
|
||||
if job is None:
|
||||
return None
|
||||
|
||||
await record_generation_event(
|
||||
db,
|
||||
job=job,
|
||||
event_type="worker_started",
|
||||
status="running",
|
||||
message="Generation worker started processing the accepted request.",
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
async def finish_generation_job(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
|
||||
@@ -12,7 +12,7 @@ from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse
|
||||
from app.db.models import ChildProfile, GenerationJob, Story, StoryUniverse
|
||||
from app.schemas.story_schemas import (
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
@@ -33,6 +33,7 @@ from app.services.audio_storage import (
|
||||
write_story_audio_cache,
|
||||
)
|
||||
from app.services.generation_jobs import (
|
||||
claim_generation_job_for_worker,
|
||||
create_generation_job,
|
||||
ensure_no_active_story_generation_job,
|
||||
finish_generation_job,
|
||||
@@ -1113,7 +1114,7 @@ async def generate_generation_service(
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> GenerationResponse:
|
||||
"""Unified generation workflow entry point for stories and storybooks."""
|
||||
"""Queue one unified generation workflow for background execution."""
|
||||
|
||||
job = await create_generation_job(
|
||||
db,
|
||||
@@ -1124,7 +1125,65 @@ async def generate_generation_service(
|
||||
)
|
||||
|
||||
try:
|
||||
response = await _generate_generation_service_with_job(request, user_id, db, job=job)
|
||||
from app.tasks.generation_workflow import run_generation_workflow_task
|
||||
|
||||
run_generation_workflow_task.delay(job.id)
|
||||
except Exception as exc:
|
||||
await finish_generation_job(
|
||||
db,
|
||||
job=job,
|
||||
story=None,
|
||||
status="failed",
|
||||
current_step="generation_failed",
|
||||
error_message="Background generation dispatch failed.",
|
||||
message="Generation failed before the worker could start processing the job.",
|
||||
metadata={"dispatch_error": str(exc)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="后台生成任务派发失败,请确认 worker 可用后重试。",
|
||||
) from exc
|
||||
|
||||
return _build_queued_generation_response(request, job_id=job.id)
|
||||
|
||||
|
||||
def _build_queued_generation_response(
|
||||
request: GenerationRequest,
|
||||
*,
|
||||
job_id: str,
|
||||
) -> GenerationResponse:
|
||||
"""Build the immediate API response after a generation job is accepted."""
|
||||
|
||||
return GenerationResponse(
|
||||
id=None,
|
||||
generation_job_id=job_id,
|
||||
title="生成任务已提交",
|
||||
mode="storybook" if request.output_mode == "storybook" else "generated",
|
||||
generation_status="queued",
|
||||
text_status="generating",
|
||||
image_status="not_requested",
|
||||
audio_status="not_requested",
|
||||
last_error=None,
|
||||
retryable_assets=[],
|
||||
child_profile_id=request.child_profile_id,
|
||||
universe_id=request.universe_id,
|
||||
)
|
||||
|
||||
|
||||
async def execute_generation_job_service(
|
||||
job: GenerationJob,
|
||||
db: AsyncSession,
|
||||
) -> GenerationResponse:
|
||||
"""Execute one previously accepted generation job inside the worker."""
|
||||
|
||||
try:
|
||||
request = GenerationRequest.model_validate(job.request_payload or {})
|
||||
response = await _generate_generation_service_with_job(
|
||||
request,
|
||||
job.user_id,
|
||||
db,
|
||||
job=job,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
await finish_generation_job(
|
||||
db,
|
||||
@@ -1151,6 +1210,21 @@ async def generate_generation_service(
|
||||
return response
|
||||
|
||||
|
||||
async def run_generation_job_service(
|
||||
job_id: str,
|
||||
db: AsyncSession,
|
||||
) -> GenerationJob | None:
|
||||
"""Claim and execute one generation job from the background queue."""
|
||||
|
||||
job = await claim_generation_job_for_worker(db, job_id=job_id)
|
||||
if job is None:
|
||||
logger.info("generation_job_execution_skipped", job_id=job_id)
|
||||
return None
|
||||
|
||||
await execute_generation_job_service(job, db)
|
||||
return job
|
||||
|
||||
|
||||
async def _generate_generation_service_with_job(
|
||||
request: GenerationRequest,
|
||||
user_id: str,
|
||||
|
||||
38
backend/app/tasks/generation_workflow.py
Normal file
38
backend/app/tasks/generation_workflow.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Background execution for unified generation workflows."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.logging import get_logger
|
||||
from app.db.database import _get_session_factory
|
||||
from app.services.story_service import run_generation_job_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def run_generation_workflow_task(job_id: str):
|
||||
"""Execute one accepted generation job in the Celery worker."""
|
||||
|
||||
logger.info("generation_workflow_task_started", job_id=job_id)
|
||||
|
||||
async def _run():
|
||||
session_factory = _get_session_factory()
|
||||
async with session_factory() as session:
|
||||
return await run_generation_job_service(job_id, session)
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
logger.info(
|
||||
"generation_workflow_task_completed",
|
||||
job_id=job_id,
|
||||
executed=bool(result),
|
||||
)
|
||||
return {"job_id": job_id, "executed": bool(result)}
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"generation_workflow_task_failed",
|
||||
job_id=job_id,
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
@@ -18,6 +18,7 @@ from app.services.generation_jobs import (
|
||||
mark_stale_generation_jobs,
|
||||
record_generation_event,
|
||||
)
|
||||
from app.services.story_service import run_generation_job_service
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -45,7 +46,7 @@ def build_storybook_output() -> Storybook:
|
||||
)
|
||||
|
||||
|
||||
async def test_unified_generation_records_job_events_and_retryable_assets(
|
||||
async def test_unified_generation_is_queued_then_worker_persists_story_and_events(
|
||||
db_session,
|
||||
test_user,
|
||||
auth_token,
|
||||
@@ -56,82 +57,108 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
task_delay_path = (
|
||||
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generation_status"] == "partial_ready"
|
||||
assert data["retryable_assets"] == ["image", "audio"]
|
||||
assert data["generation_job_id"]
|
||||
|
||||
jobs = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
).scalars().all()
|
||||
assert len(jobs) == 1
|
||||
job = jobs[0]
|
||||
assert job.story_id == data["id"]
|
||||
assert job.output_mode == "story"
|
||||
assert job.input_type == "keywords"
|
||||
assert job.status == "completed"
|
||||
assert job.current_step == "generation_completed"
|
||||
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
|
||||
assert data["generation_job_id"] == job.id
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == job.id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
assert [event.event_type for event in events] == [
|
||||
"request_accepted",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
assert events[1].event_metadata["has_memory_context"] is False
|
||||
assert events[2].event_metadata["title"] == "小兔子的冒险"
|
||||
assert events[3].story_id == data["id"]
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["id"] is None
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
assert data["retryable_assets"] == []
|
||||
assert data["generation_job_id"]
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
|
||||
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
|
||||
assert detail_response.status_code == 200
|
||||
detail = detail_response.json()
|
||||
assert detail["id"] == job.id
|
||||
assert detail["story_id"] == data["id"]
|
||||
assert detail["progress_percent"] == 100
|
||||
assert detail["progress_label"] == "已完成"
|
||||
assert detail["is_terminal"] is True
|
||||
assert [event["event_type"] for event in detail["events"]] == [
|
||||
"request_accepted",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
jobs = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(GenerationJob.user_id == test_user.id)
|
||||
)
|
||||
).scalars().all()
|
||||
assert len(jobs) == 1
|
||||
job = jobs[0]
|
||||
assert job.story_id is None
|
||||
assert job.output_mode == "story"
|
||||
assert job.input_type == "keywords"
|
||||
assert job.status == "running"
|
||||
assert job.current_step == "request_accepted"
|
||||
assert data["generation_job_id"] == job.id
|
||||
|
||||
list_response = await client.get(f"/api/generations/{data['id']}/jobs")
|
||||
assert list_response.status_code == 200
|
||||
job_list = list_response.json()
|
||||
assert [item["id"] for item in job_list] == [job.id]
|
||||
assert job_list[0]["progress_percent"] == 100
|
||||
assert job_list[0]["is_terminal"] is True
|
||||
await run_generation_job_service(job.id, db_session)
|
||||
|
||||
job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(GenerationJob.id == job.id)
|
||||
)
|
||||
).scalar_one()
|
||||
assert job.story_id is not None
|
||||
assert job.status == "completed"
|
||||
assert job.current_step == "generation_completed"
|
||||
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == job.id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
assert [event.event_type for event in events] == [
|
||||
"request_accepted",
|
||||
"worker_started",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
assert events[2].event_metadata["has_memory_context"] is False
|
||||
assert events[3].event_metadata["title"] == "小兔子的冒险"
|
||||
assert events[4].story_id == job.story_id
|
||||
|
||||
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
|
||||
assert detail_response.status_code == 200
|
||||
detail = detail_response.json()
|
||||
assert detail["id"] == job.id
|
||||
assert detail["story_id"] == job.story_id
|
||||
assert detail["progress_percent"] == 100
|
||||
assert detail["progress_label"] == "已完成"
|
||||
assert detail["is_terminal"] is True
|
||||
assert [event["event_type"] for event in detail["events"]] == [
|
||||
"request_accepted",
|
||||
"worker_started",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
|
||||
story_response = await client.get(f"/api/generations/{job.story_id}")
|
||||
assert story_response.status_code == 200
|
||||
story_data = story_response.json()
|
||||
assert story_data["generation_status"] == "partial_ready"
|
||||
assert story_data["retryable_assets"] == ["image", "audio"]
|
||||
|
||||
list_response = await client.get(f"/api/generations/{job.story_id}/jobs")
|
||||
assert list_response.status_code == 200
|
||||
job_list = list_response.json()
|
||||
assert [item["id"] for item in job_list] == [job.id]
|
||||
assert job_list[0]["progress_percent"] == 100
|
||||
assert job_list[0]["is_terminal"] is True
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@@ -196,7 +223,7 @@ async def test_asset_retry_records_job_events_and_updates_retryable_assets(
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_storybook_generation_records_page_image_events(
|
||||
async def test_storybook_generation_is_queued_then_worker_records_page_image_events(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
@@ -222,61 +249,78 @@ async def test_storybook_generation_records_page_image_events(
|
||||
"https://example.com/storybook-page-2.png",
|
||||
]
|
||||
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "storybook",
|
||||
"type": "keywords",
|
||||
"data": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["image_status"] == "ready"
|
||||
|
||||
job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(
|
||||
GenerationJob.story_id == data["id"],
|
||||
GenerationJob.output_mode == "storybook",
|
||||
task_delay_path = (
|
||||
"app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
with patch(task_delay_path) as mock_delay:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == job.id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
response = await client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "storybook",
|
||||
"type": "keywords",
|
||||
"data": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert [event.event_type for event in events] == [
|
||||
"request_accepted",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"storybook_images_started",
|
||||
"storybook_cover_image_succeeded",
|
||||
"storybook_page_image_succeeded",
|
||||
"storybook_page_image_succeeded",
|
||||
"storybook_images_completed",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
page_events = [
|
||||
event
|
||||
for event in events
|
||||
if event.event_type == "storybook_page_image_succeeded"
|
||||
]
|
||||
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
|
||||
assert events[7].event_metadata["completed_pages"] == [1, 2]
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["id"] is None
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(
|
||||
GenerationJob.id == data["generation_job_id"],
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
await run_generation_job_service(job.id, db_session)
|
||||
|
||||
job = (
|
||||
await db_session.execute(
|
||||
select(GenerationJob).where(
|
||||
GenerationJob.id == job.id,
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
assert job.story_id is not None
|
||||
assert job.status == "completed"
|
||||
|
||||
events = (
|
||||
await db_session.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == job.id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
assert [event.event_type for event in events] == [
|
||||
"request_accepted",
|
||||
"worker_started",
|
||||
"context_prepared",
|
||||
"narrative_generated",
|
||||
"storybook_images_started",
|
||||
"storybook_cover_image_succeeded",
|
||||
"storybook_page_image_succeeded",
|
||||
"storybook_page_image_succeeded",
|
||||
"storybook_images_completed",
|
||||
"story_saved",
|
||||
"generation_completed",
|
||||
]
|
||||
page_events = [
|
||||
event
|
||||
for event in events
|
||||
if event.event_type == "storybook_page_image_succeeded"
|
||||
]
|
||||
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
|
||||
assert events[8].event_metadata["completed_pages"] == [1, 2]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@@ -430,6 +430,8 @@ class TestGenerateFull:
|
||||
class TestUnifiedGenerations:
|
||||
"""Tests for the target unified generation API."""
|
||||
|
||||
TASK_DELAY_PATH = "app.tasks.generation_workflow.run_generation_workflow_task.delay"
|
||||
|
||||
def test_create_generation_without_auth(self, client: TestClient):
|
||||
response = client.post(
|
||||
"/api/generations",
|
||||
@@ -443,60 +445,64 @@ class TestUnifiedGenerations:
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林, 勇气",
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
with patch(self.TASK_DELAY_PATH) as mock_delay:
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林, 勇气",
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 202
|
||||
assert "Deprecation" not in response.headers
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["id"] is None
|
||||
assert data["mode"] == "generated"
|
||||
assert data["story_text"] == "从前有一只小兔子。"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["cover_url"] == "https://example.com/image.png"
|
||||
assert data["story_text"] is None
|
||||
assert data["image_url"] is None
|
||||
assert data["cover_url"] is None
|
||||
assert data["pages"] is None
|
||||
assert data["generation_status"] == "partial_ready"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["errors"] == {}
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
|
||||
def test_create_story_generation_without_assets(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
with patch(self.TASK_DELAY_PATH) as mock_delay:
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] is None
|
||||
assert data["generation_status"] == "partial_ready"
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
assert data["image_status"] == "not_requested"
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
|
||||
def test_create_story_generation_image_failure(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
):
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
|
||||
with patch(self.TASK_DELAY_PATH) as mock_delay:
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
@@ -507,55 +513,45 @@ class TestUnifiedGenerations:
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["image_url"] is None
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "Image API error" in data["errors"]["image"]
|
||||
assert "Image API error" in data["last_error"]
|
||||
assert data["errors"] == {}
|
||||
assert data["last_error"] is None
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
|
||||
def test_create_storybook_generation_success(self, auth_client: TestClient):
|
||||
with patch(
|
||||
"app.services.story_service.generate_storybook",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_storybook:
|
||||
with patch(
|
||||
"app.services.story_service.generate_image",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = [
|
||||
"https://example.com/storybook-cover.png",
|
||||
"https://example.com/storybook-page-1.png",
|
||||
"https://example.com/storybook-page-2.png",
|
||||
]
|
||||
with patch(self.TASK_DELAY_PATH) as mock_delay:
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "storybook",
|
||||
"type": "keywords",
|
||||
"data": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "storybook",
|
||||
"type": "keywords",
|
||||
"data": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["id"] is None
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["story_text"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["cover_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["image_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["main_character"] == "小兔子露露"
|
||||
assert data["art_style"] == "温暖水彩"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["pages"] is None
|
||||
assert data["cover_url"] is None
|
||||
assert data["image_url"] is None
|
||||
assert data["main_character"] is None
|
||||
assert data["art_style"] is None
|
||||
assert data["generation_status"] == "queued"
|
||||
assert data["text_status"] == "generating"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
mock_delay.assert_called_once_with(data["generation_job_id"])
|
||||
|
||||
def test_get_generation_alias(self, auth_client: TestClient, test_story):
|
||||
response = auth_client.get(f"/api/generations/{test_story.id}")
|
||||
|
||||
Reference in New Issue
Block a user