feat: add generation trace and partial-ready workflow status

This commit is contained in:
2026-04-18 21:53:55 +08:00
parent 96dfc677e2
commit e99a7fbe14
36 changed files with 2597 additions and 144 deletions

View File

@@ -0,0 +1,107 @@
"""add story text status and partial ready semantics
Revision ID: 0012_story_text_status
Revises: 0011_add_generation_jobs
Create Date: 2026-04-18
"""
import sqlalchemy as sa
from alembic import op
revision = "0012_story_text_status"
down_revision = "0011_add_generation_jobs"
branch_labels = None
depends_on = None
stories = sa.table(
"stories",
sa.column("id", sa.Integer),
sa.column("story_text", sa.Text),
sa.column("pages", sa.JSON),
sa.column("cover_prompt", sa.Text),
sa.column("image_url", sa.String(length=500)),
sa.column("generation_status", sa.String(length=32)),
sa.column("text_status", sa.String(length=32)),
sa.column("image_status", sa.String(length=32)),
sa.column("audio_status", sa.String(length=32)),
)
def _has_narrative(row: dict) -> bool:
return bool(row.get("story_text")) or bool(row.get("pages"))
def _has_pending_image(row: dict) -> bool:
if row.get("image_status") in {"ready", "generating"}:
return False
pages = row.get("pages") or []
has_missing_page_image = any(
isinstance(page, dict)
and page.get("image_prompt")
and not page.get("image_url")
for page in pages
)
return bool(row.get("cover_prompt") and not row.get("image_url")) or has_missing_page_image
def _resolve_generation_status(row: dict) -> str:
if not _has_narrative(row):
return "failed"
image_status = row.get("image_status") or "not_requested"
audio_status = row.get("audio_status") or "not_requested"
if "generating" in {image_status, audio_status}:
return "assets_generating"
if "failed" in {image_status, audio_status}:
return "degraded_completed"
has_pending_audio = bool(row.get("story_text")) and audio_status not in {
"ready",
"generating",
}
if _has_pending_image(row) or has_pending_audio:
return "partial_ready"
if image_status == "not_requested" and audio_status == "not_requested":
return "narrative_ready"
return "completed"
def upgrade() -> None:
op.add_column(
"stories",
sa.Column("text_status", sa.String(length=32), nullable=False, server_default="ready"),
)
connection = op.get_bind()
rows = connection.execute(
sa.select(
stories.c.id,
stories.c.story_text,
stories.c.pages,
stories.c.cover_prompt,
stories.c.image_url,
stories.c.image_status,
stories.c.audio_status,
)
).mappings()
for row in rows:
text_status = "ready" if _has_narrative(row) else "failed"
generation_status = _resolve_generation_status(row)
connection.execute(
stories.update()
.where(stories.c.id == row["id"])
.values(text_status=text_status, generation_status=generation_status)
)
def downgrade() -> None:
op.drop_column("stories", "text_status")

View File

@@ -17,6 +17,9 @@ from app.schemas.story_schemas import (
AchievementItem,
FullStoryResponse,
GenerateRequest,
GenerationJobDetailResponse,
GenerationJobSummaryResponse,
GenerationProviderStatsResponse,
GenerationRequest,
GenerationResponse,
StoryAssetRetryRequest,
@@ -28,6 +31,11 @@ from app.schemas.story_schemas import (
StoryResponse,
)
from app.services import story_service
from app.services.generation_jobs import (
get_generation_job_detail,
get_story_provider_stats,
list_story_generation_jobs,
)
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import (
generate_image,
@@ -65,6 +73,42 @@ async def create_generation(
return await story_service.generate_generation_service(request, user.id, db)
@router.get("/generations/jobs/{job_id}", response_model=GenerationJobDetailResponse)
async def get_generation_job(
job_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get one generation job with ordered workflow events."""
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
@router.get(
"/generations/{story_id}/jobs",
response_model=list[GenerationJobSummaryResponse],
)
async def list_generation_jobs(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""List recent generation jobs for a generated story/storybook."""
return await list_story_generation_jobs(db, story_id=story_id, user_id=user.id)
@router.get(
"/generations/{story_id}/provider-stats",
response_model=GenerationProviderStatsResponse,
)
async def get_generation_provider_stats(
story_id: int,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get provider call stats aggregated from generation job events."""
return await get_story_provider_stats(db, story_id=story_id, user_id=user.id)
@router.get("/generations/{story_id}", response_model=StoryDetailResponse)
async def get_generation(
story_id: int,
@@ -135,13 +179,14 @@ async def generate_story_stream(
# Step 1: Generate Content
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
user_id=user.id,
db=db,
)
except Exception as e:
logger.error("sse_story_generation_failed", error=str(e))
yield {"event": "story_failed", "data": json.dumps({"error": str(e)})}
@@ -163,6 +208,7 @@ async def generate_story_stream(
"child_profile_id": story.child_profile_id,
"universe_id": story.universe_id,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
@@ -175,7 +221,12 @@ async def generate_story_stream(
await db.commit()
try:
# Direct call to provider router's generate_image, sharing db session
image_url = await generate_image(story.cover_prompt, db=db)
image_url = await generate_image(
story.cover_prompt,
db=db,
user_id=user.id,
story_id=story.id,
)
story.image_url = image_url
sync_story_status(
story,
@@ -188,6 +239,7 @@ async def generate_story_stream(
{
"image_url": image_url,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
@@ -208,6 +260,7 @@ async def generate_story_stream(
{
"error": str(e),
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
@@ -221,6 +274,7 @@ async def generate_story_stream(
{
"story_id": story.id,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,
@@ -296,6 +350,7 @@ async def generate_story_image(
return {
"image_url": url,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"last_error": story.last_error,

View File

@@ -67,6 +67,9 @@ class Story(Base):
generation_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="narrative_ready"
)
text_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="ready"
)
image_status: Mapped[str] = mapped_column(
String(32), nullable=False, default="not_requested"
)

View File

@@ -1,7 +1,7 @@
"""Story-related Pydantic schemas."""
from datetime import datetime
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
@@ -14,6 +14,7 @@ class StoryStatusMixin(BaseModel):
"""Shared generation status fields returned by story APIs."""
generation_status: str
text_status: str
image_status: str
audio_status: str
last_error: str | None = None
@@ -117,6 +118,7 @@ class GenerationResponse(StoryStatusMixin):
"""Unified generation response for the target workflow API."""
id: int
generation_job_id: str | None = None
title: str
mode: str
story_text: str | None = None
@@ -158,6 +160,68 @@ class StoryAssetRetryRequest(BaseModel):
assets: list[Literal["image", "audio"]] = Field(..., min_length=1)
class GenerationJobEventResponse(BaseModel):
"""One persisted event emitted by a generation job."""
id: int
job_id: str
story_id: int | None = None
event_type: str
status: str
message: str | None = None
event_metadata: dict[str, Any] = Field(default_factory=dict)
created_at: datetime
class GenerationJobSummaryResponse(BaseModel):
"""Generation job summary for progress lists."""
id: str
story_id: int | None = None
output_mode: str
input_type: str
status: str
current_step: str
progress_percent: int
progress_label: str
is_terminal: bool
result_snapshot: dict[str, Any] = Field(default_factory=dict)
error_message: str | None = None
created_at: datetime
updated_at: datetime
class GenerationJobDetailResponse(GenerationJobSummaryResponse):
"""Generation job detail with append-only workflow events."""
request_payload: dict[str, Any] = Field(default_factory=dict)
events: list[GenerationJobEventResponse] = Field(default_factory=list)
class GenerationProviderStatResponse(BaseModel):
"""Aggregated provider call stats for one adapter/capability pair."""
capability: str
adapter: str
call_count: int
success_count: int
failure_count: int
avg_latency_ms: float | None = None
estimated_cost_usd: float = 0.0
class GenerationProviderStatsResponse(BaseModel):
"""Provider call stats aggregated from generation job events."""
story_id: int
total_calls: int
successful_calls: int
failed_calls: int
avg_latency_ms: float | None = None
estimated_cost_usd: float = 0.0
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
class AchievementItem(BaseModel):
"""Achievement item returned for a story."""

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import GenerationJob, GenerationJobEvent, Story
@@ -17,6 +19,7 @@ def _story_snapshot(story: Story | None) -> dict[str, Any]:
"story_id": story.id,
"mode": story.mode,
"generation_status": story.generation_status,
"text_status": story.text_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
"retryable_assets": story.retryable_assets,
@@ -32,6 +35,48 @@ def _job_status_from_story(story: Story) -> str:
return "completed"
def _job_progress(job: GenerationJob) -> dict[str, Any]:
"""Resolve a compact progress summary for polling-oriented clients."""
if job.status == "failed":
return {
"progress_percent": 100,
"progress_label": "生成失败",
"is_terminal": True,
}
if job.status in {"completed", "degraded_completed"}:
return {
"progress_percent": 100,
"progress_label": "已完成" if job.status == "completed" else "降级完成",
"is_terminal": True,
}
progress_map: dict[str, tuple[int, str]] = {
"request_accepted": (5, "已接收请求"),
"context_prepared": (20, "上下文已准备"),
"narrative_generated": (45, "正文已生成"),
"story_saved": (60, "主记录已保存"),
"provider_call_started": (65, "Provider 调用中"),
"provider_call_succeeded": (72, "Provider 调用成功"),
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
"cover_image_started": (75, "封面生成中"),
"storybook_images_started": (75, "绘本插图生成中"),
"audio_started": (75, "音频生成中"),
"asset_retry_started": (25, "资源重试中"),
"postprocessing_queued": (90, "后处理已排队"),
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_completed": (100, "生成完成"),
}
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
return {
"progress_percent": percent,
"progress_label": label,
"is_terminal": percent >= 100,
}
async def create_generation_job(
db: AsyncSession,
*,
@@ -131,3 +176,198 @@ async def finish_generation_job(
await db.commit()
await db.refresh(job)
return job
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
"""Convert a generation event ORM object to an API response dict."""
return {
"id": event.id,
"job_id": event.job_id,
"story_id": event.story_id,
"event_type": event.event_type,
"status": event.status,
"message": event.message,
"event_metadata": event.event_metadata or {},
"created_at": event.created_at,
}
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
"""Convert a generation job ORM object to an API summary dict."""
progress = _job_progress(job)
return {
"id": job.id,
"story_id": job.story_id,
"output_mode": job.output_mode,
"input_type": job.input_type,
"status": job.status,
"current_step": job.current_step,
**progress,
"result_snapshot": job.result_snapshot or {},
"error_message": job.error_message,
"created_at": job.created_at,
"updated_at": job.updated_at,
}
async def get_generation_job_detail(
db: AsyncSession,
*,
job_id: str,
user_id: str,
) -> dict[str, Any]:
"""Return a user-owned generation job with its ordered event stream."""
result = await db.execute(
select(GenerationJob).where(
GenerationJob.id == job_id,
GenerationJob.user_id == user_id,
)
)
job = result.scalar_one_or_none()
if job is None:
raise HTTPException(status_code=404, detail="Generation job not found")
events = (
await db.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
return {
**generation_job_to_summary(job),
"request_payload": job.request_payload or {},
"events": [generation_event_to_response(event) for event in events],
}
async def list_story_generation_jobs(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> list[dict[str, Any]]:
"""Return recent generation jobs for a user-owned story."""
jobs = (
await db.execute(
select(GenerationJob)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
)
.order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
)
).scalars().all()
return [generation_job_to_summary(job) for job in jobs]
def _as_float(value: Any) -> float | None:
if isinstance(value, int | float):
return float(value)
return None
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
by_key: dict[tuple[str, str], dict[str, Any]] = {}
total_latency = 0.0
latency_count = 0
total_cost = 0.0
successful_calls = 0
failed_calls = 0
for event in events:
metadata = event.event_metadata or {}
capability = str(metadata.get("capability") or "unknown")
adapter = str(metadata.get("adapter") or "unknown")
key = (capability, adapter)
bucket = by_key.setdefault(
key,
{
"capability": capability,
"adapter": adapter,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"latency_total": 0.0,
"latency_count": 0,
"estimated_cost_usd": 0.0,
},
)
bucket["call_count"] += 1
latency = _as_float(metadata.get("latency_ms"))
if latency is not None:
bucket["latency_total"] += latency
bucket["latency_count"] += 1
total_latency += latency
latency_count += 1
if event.event_type == "provider_call_succeeded":
bucket["success_count"] += 1
successful_calls += 1
cost = _as_float(metadata.get("estimated_cost_usd")) or 0.0
bucket["estimated_cost_usd"] += cost
total_cost += cost
else:
bucket["failure_count"] += 1
failed_calls += 1
by_provider = []
for bucket in by_key.values():
bucket_latency_count = bucket.pop("latency_count")
bucket_latency_total = bucket.pop("latency_total")
by_provider.append(
{
**bucket,
"avg_latency_ms": (
round(bucket_latency_total / bucket_latency_count, 2)
if bucket_latency_count
else None
),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
)
by_provider.sort(
key=lambda item: (
str(item["capability"]),
str(item["adapter"]),
)
)
return {
"story_id": story_id,
"total_calls": successful_calls + failed_calls,
"successful_calls": successful_calls,
"failed_calls": failed_calls,
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
"estimated_cost_usd": round(total_cost, 6),
"by_provider": by_provider,
}

View File

@@ -10,6 +10,7 @@ from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.generation_jobs import record_generation_event
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
@@ -22,6 +23,7 @@ from app.services.provider_policy import (
if TYPE_CHECKING:
from app.db.admin_models import Provider
from app.db.models import GenerationJob
logger = get_logger(__name__)
@@ -36,6 +38,58 @@ _round_robin_counters: dict[ProviderType, int] = {
_latency_cache: dict[str, float] = {}
def _safe_estimated_cost(adapter) -> float:
"""Return an adapter cost value that is safe to serialize in job events."""
try:
return float(adapter.estimated_cost)
except Exception:
return 0.0
async def _record_provider_event_if_present(
db: AsyncSession | None,
*,
job: "GenerationJob | None",
event_type: str,
status: str,
provider_type: ProviderType,
adapter_name: str,
strategy: RoutingStrategy,
provider_id: str | None = None,
story_id: int | None = None,
latency_ms: int | None = None,
estimated_cost: float | None = None,
error: str | None = None,
) -> None:
"""Append provider call telemetry to the active generation job."""
if db is None or job is None:
return
await record_generation_event(
db,
job=job,
story_id=story_id,
event_type=event_type,
status=status,
message=(
f"{provider_type} provider {adapter_name} {status}."
if error is None
else f"{provider_type} provider {adapter_name} failed."
),
metadata={
"capability": provider_type,
"adapter": adapter_name,
"provider_id": provider_id,
"strategy": strategy.value,
"latency_ms": latency_ms,
"estimated_cost_usd": estimated_cost,
"error": error,
},
)
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
"""根据 config_ref 或适配器名称获取 API Key。"""
# 优先使用 config_ref
@@ -228,6 +282,8 @@ async def _route_with_failover(
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
generation_job: "GenerationJob | None" = None,
story_id: int | None = None,
**kwargs,
) -> T:
"""通用 provider failover 路由。
@@ -237,6 +293,8 @@ async def _route_with_failover(
strategy: 路由策略
db: 数据库会话(可选,用于指标收集和熔断检查)
user_id: 用户 ID可选用于成本追踪和预算检查
generation_job: 生成任务(可选,用于记录 provider 调用轨迹)
story_id: 故事 ID可选用于关联 provider 事件)
**kwargs: 传递给适配器的参数
"""
providers = await _get_providers_with_config(provider_type)
@@ -274,7 +332,9 @@ async def _route_with_failover(
errors.append(f"{name}: 适配器未注册")
continue
provider_id = db_provider.id if db_provider else None
provider_id = str(db_provider.id) if db_provider else None
estimated_cost: float | None = None
start_time: float | None = None
try:
logger.debug(
@@ -285,6 +345,20 @@ async def _route_with_failover(
)
adapter = adapter_class(config)
estimated_cost = _safe_estimated_cost(adapter)
await _record_provider_event_if_present(
db,
job=generation_job,
story_id=story_id,
event_type="provider_call_started",
status="running",
provider_type=provider_type,
adapter_name=name,
provider_id=provider_id,
strategy=strategy,
estimated_cost=estimated_cost,
)
# 执行并计时
start_time = time.time()
@@ -301,7 +375,7 @@ async def _route_with_failover(
provider_id=provider_id,
success=True,
latency_ms=latency_ms,
cost_usd=adapter.estimated_cost,
cost_usd=estimated_cost,
)
await health_checker.record_call_result(db, provider_id, success=True)
@@ -312,10 +386,24 @@ async def _route_with_failover(
user_id=user_id,
provider_name=name,
capability=provider_type,
estimated_cost=adapter.estimated_cost,
estimated_cost=estimated_cost,
provider_id=provider_id,
)
await _record_provider_event_if_present(
db,
job=generation_job,
story_id=story_id,
event_type="provider_call_succeeded",
status="succeeded",
provider_type=provider_type,
adapter_name=name,
provider_id=provider_id,
strategy=strategy,
latency_ms=latency_ms,
estimated_cost=estimated_cost,
)
logger.info(
"provider_success",
provider_type=provider_type,
@@ -326,6 +414,11 @@ async def _route_with_failover(
except Exception as exc:
error_msg = str(exc)
latency_ms = (
int((time.time() - start_time) * 1000)
if start_time is not None
else None
)
logger.warning(
"provider_failed",
provider_type=provider_type,
@@ -346,6 +439,21 @@ async def _route_with_failover(
db, provider_id, success=False, error=error_msg
)
await _record_provider_event_if_present(
db,
job=generation_job,
story_id=story_id,
event_type="provider_call_failed",
status="failed",
provider_type=provider_type,
adapter_name=name,
provider_id=provider_id,
strategy=strategy,
latency_ms=latency_ms,
estimated_cost=estimated_cost,
error=error_msg,
)
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
@@ -356,12 +464,16 @@ async def generate_story_content(
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
generation_job: "GenerationJob | None" = None,
) -> StoryOutput:
"""生成或润色故事,支持 failover。"""
return await _route_with_failover(
"text",
strategy=strategy,
db=db,
user_id=user_id,
generation_job=generation_job,
input_type=input_type,
data=data,
education_theme=education_theme,
@@ -373,19 +485,42 @@ async def generate_image(
prompt: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
generation_job: "GenerationJob | None" = None,
story_id: int | None = None,
**kwargs,
) -> str:
"""生成图片,返回 URL支持 failover。"""
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
return await _route_with_failover(
"image",
strategy=strategy,
db=db,
user_id=user_id,
generation_job=generation_job,
story_id=story_id,
prompt=prompt,
**kwargs,
)
async def text_to_speech(
text: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
generation_job: "GenerationJob | None" = None,
story_id: int | None = None,
) -> bytes:
"""文本转语音,返回 MP3 bytes支持 failover。"""
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
return await _route_with_failover(
"tts",
strategy=strategy,
db=db,
user_id=user_id,
generation_job=generation_job,
story_id=story_id,
text=text,
)
async def generate_storybook(
@@ -395,6 +530,8 @@ async def generate_storybook(
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
generation_job: "GenerationJob | None" = None,
):
"""生成分页故事书,支持 failover。"""
from app.services.adapters.storybook.primary import Storybook
@@ -403,6 +540,8 @@ async def generate_storybook(
"storybook",
strategy=strategy,
db=db,
user_id=user_id,
generation_job=generation_job,
keywords=keywords,
page_count=page_count,
education_theme=education_theme,

View File

@@ -67,6 +67,43 @@ class AssetCompletionResult:
return self.status == StoryAssetStatus.READY and self.error is None
async def _record_job_event_if_present(
db: AsyncSession,
*,
job,
event_type: str,
status: str,
story_id: int | None = None,
message: str | None = None,
metadata: dict | None = None,
) -> None:
"""Append a workflow event when the caller is running under a tracked job."""
if job is None:
return
await record_generation_event(
db,
job=job,
story_id=story_id,
event_type=event_type,
status=status,
message=message,
metadata=metadata,
)
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
"""Build JSON-safe metadata for asset workflow events."""
return {
"asset": result.asset,
"status": result.status.value,
"error": result.error,
"blocks_main_result": result.blocks_main_result,
}
def _build_storybook_error_message(
*,
cover_failed: bool,
@@ -125,6 +162,7 @@ async def _prepare_generation_context(
universe_id: str | None,
user_id: str,
db: AsyncSession,
job=None,
) -> tuple[str | None, str | None, str]:
"""Validate ownership and build the shared generation context."""
@@ -136,6 +174,18 @@ async def _prepare_generation_context(
resolved_universe_id,
db,
)
await _record_job_event_if_present(
db,
job=job,
event_type="context_prepared",
status="succeeded",
message="Profile, universe, and memory context were prepared.",
metadata={
"profile_id": resolved_profile_id,
"universe_id": resolved_universe_id,
"has_memory_context": bool(memory_context),
},
)
return resolved_profile_id, resolved_universe_id, memory_context
@@ -173,6 +223,7 @@ async def _persist_text_story_result(
profile_id: str | None,
universe_id: str | None,
db: AsyncSession,
job=None,
) -> Story:
"""Persist generated text content as the unified story record."""
@@ -195,6 +246,20 @@ async def _persist_text_story_result(
await db.commit()
await db.refresh(story)
_trigger_story_postprocessing(story)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="story_saved",
status="succeeded",
message="Readable story record was saved.",
metadata={
"mode": story.mode,
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
},
)
return story
@@ -229,21 +294,47 @@ def _storybook_pages_to_response(pages_data: list[dict]) -> list[StorybookPageRe
async def _generate_storybook_image_assets(
storybook: Storybook,
db: AsyncSession,
*,
user_id: str,
job=None,
) -> tuple[str | None, bool, list[int]]:
"""Generate storybook cover and page images before persistence."""
final_cover_url = storybook.cover_url
cover_failed = False
failed_pages: list[int] = []
completed_pages: list[int] = []
attempted_cover = bool(storybook.cover_prompt and not storybook.cover_url)
attempted_pages = [
page.page_number
for page in storybook.pages
if page.image_prompt and not page.image_url
]
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
await _record_job_event_if_present(
db,
job=job,
event_type="storybook_images_started",
status="running",
message="Storybook cover and page image generation started.",
metadata={
"attempted_cover": attempted_cover,
"attempted_pages": attempted_pages,
},
)
async def _gen_cover() -> str | None:
nonlocal cover_failed
if storybook.cover_prompt and not storybook.cover_url:
try:
return await generate_image(storybook.cover_prompt, db=db)
return await generate_image(
storybook.cover_prompt,
db=db,
user_id=user_id,
generation_job=job,
)
except Exception as exc:
cover_failed = True
logger.warning("cover_gen_failed", error=str(exc))
@@ -254,7 +345,13 @@ async def _generate_storybook_image_assets(
return
try:
page.image_url = await generate_image(page.image_prompt, db=db)
page.image_url = await generate_image(
page.image_prompt,
db=db,
user_id=user_id,
generation_job=job,
)
completed_pages.append(page.page_number)
except Exception as exc:
failed_pages.append(page.page_number)
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
@@ -270,6 +367,57 @@ async def _generate_storybook_image_assets(
final_cover_url = cover_result
logger.info("storybook_parallel_generation_complete")
if attempted_cover:
await _record_job_event_if_present(
db,
job=job,
event_type=(
"storybook_cover_image_failed"
if cover_failed
else "storybook_cover_image_succeeded"
),
status="failed" if cover_failed else "succeeded",
message=(
"Storybook cover image generation failed."
if cover_failed
else "Storybook cover image was generated."
),
metadata={"asset": "image", "scope": "cover"},
)
for page_number in sorted(completed_pages):
await _record_job_event_if_present(
db,
job=job,
event_type="storybook_page_image_succeeded",
status="succeeded",
message="Storybook page image was generated.",
metadata={"asset": "image", "scope": "page", "page_number": page_number},
)
for page_number in sorted(failed_pages):
await _record_job_event_if_present(
db,
job=job,
event_type="storybook_page_image_failed",
status="failed",
message="Storybook page image generation failed.",
metadata={"asset": "image", "scope": "page", "page_number": page_number},
)
await _record_job_event_if_present(
db,
job=job,
event_type="storybook_images_completed",
status="failed" if cover_failed or failed_pages else "succeeded",
message="Storybook image generation finished.",
metadata={
"asset": "image",
"attempted_cover": attempted_cover,
"completed_pages": sorted(completed_pages),
"failed_pages": sorted(failed_pages),
},
)
return final_cover_url, cover_failed, failed_pages
@@ -284,6 +432,7 @@ async def _persist_storybook_result(
cover_failed: bool,
failed_pages: list[int],
db: AsyncSession,
job=None,
) -> tuple[Story, list[dict]]:
"""Persist generated storybook content as the unified story record."""
@@ -317,6 +466,21 @@ async def _persist_storybook_result(
await db.commit()
await db.refresh(story)
_trigger_story_postprocessing(story)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="story_saved",
status="succeeded",
message="Storybook record was saved.",
metadata={
"mode": story.mode,
"page_count": len(pages_data),
"generation_status": story.generation_status,
"image_status": story.image_status,
"audio_status": story.audio_status,
},
)
return story, pages_data
@@ -327,6 +491,7 @@ async def _complete_cover_image_asset(
raise_on_failure: bool = False,
last_error_prefix: str | None = None,
log_event: str = "cover_asset_generation_failed",
job=None,
) -> AssetCompletionResult:
"""Generate or retry a text story cover through one asset workflow."""
@@ -335,18 +500,43 @@ async def _complete_cover_image_asset(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="cover_image_started",
status="running",
message="Cover image generation started.",
metadata={"asset": "image", "cover_prompt_present": True},
)
try:
image_url = await generate_image(story.cover_prompt, db=db)
image_url = await generate_image(
story.cover_prompt,
db=db,
user_id=story.user_id,
generation_job=job,
story_id=story.id,
)
story.image_url = image_url
sync_story_status(story, image_status=StoryAssetStatus.READY)
await db.commit()
return AssetCompletionResult(
result = AssetCompletionResult(
asset="cover_image",
status=StoryAssetStatus.READY,
value=image_url,
blocks_main_result=raise_on_failure,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="cover_image_succeeded",
status="succeeded",
message="Cover image was generated.",
metadata=_asset_result_metadata(result),
)
return result
except Exception as exc:
provider_error = str(exc)
last_error = (
@@ -362,18 +552,28 @@ async def _complete_cover_image_asset(
await db.commit()
logger.warning(log_event, story_id=story.id, error=provider_error)
result = AssetCompletionResult(
asset="cover_image",
status=StoryAssetStatus.FAILED,
error=provider_error,
blocks_main_result=raise_on_failure,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="cover_image_failed",
status="failed",
message="Cover image generation failed.",
metadata=_asset_result_metadata(result),
)
if raise_on_failure:
raise HTTPException(
status_code=500,
detail=f"Image generation failed: {provider_error}",
) from exc
return AssetCompletionResult(
asset="cover_image",
status=StoryAssetStatus.FAILED,
error=provider_error,
blocks_main_result=raise_on_failure,
)
return result
def _get_storybook_pages_data(story: Story) -> list[dict]:
@@ -385,6 +585,8 @@ def _get_storybook_pages_data(story: Story) -> list[dict]:
async def _complete_storybook_image_assets(
story: Story,
db: AsyncSession,
*,
job=None,
) -> AssetCompletionResult:
"""Complete missing cover/page images for a persisted storybook."""
@@ -397,13 +599,38 @@ async def _complete_storybook_image_assets(
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
await db.commit()
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_images_started",
status="running",
message="Storybook missing image completion started.",
metadata={"asset": "image"},
)
cover_failed = False
failed_pages: list[int] = []
completed_pages: list[int] = []
if story.cover_prompt and not story.image_url:
try:
story.image_url = await generate_image(story.cover_prompt, db=db)
story.image_url = await generate_image(
story.cover_prompt,
db=db,
user_id=story.user_id,
generation_job=job,
story_id=story.id,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_cover_image_succeeded",
status="succeeded",
message="Storybook cover image was generated.",
metadata={"asset": "image", "scope": "cover"},
)
except Exception as exc:
cover_failed = True
logger.warning(
@@ -411,13 +638,40 @@ async def _complete_storybook_image_assets(
story_id=story.id,
error=str(exc),
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_cover_image_failed",
status="failed",
message="Storybook cover image generation failed.",
metadata={"asset": "image", "scope": "cover", "error": str(exc)},
)
for page in pages_data:
if not page.get("image_prompt") or page.get("image_url"):
continue
try:
page["image_url"] = await generate_image(page["image_prompt"], db=db)
page["image_url"] = await generate_image(
page["image_prompt"],
db=db,
user_id=story.user_id,
generation_job=job,
story_id=story.id,
)
page_number = page.get("page_number")
if isinstance(page_number, int):
completed_pages.append(page_number)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_page_image_succeeded",
status="succeeded",
message="Storybook page image was generated.",
metadata={"asset": "image", "scope": "page", "page_number": page_number},
)
except Exception as exc:
page_number = page.get("page_number")
if isinstance(page_number, int):
@@ -428,6 +682,20 @@ async def _complete_storybook_image_assets(
page=page_number,
error=str(exc),
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_page_image_failed",
status="failed",
message="Storybook page image generation failed.",
metadata={
"asset": "image",
"scope": "page",
"page_number": page_number,
"error": str(exc),
},
)
story.pages = pages_data
error_message = _build_storybook_error_message(
@@ -446,12 +714,26 @@ async def _complete_storybook_image_assets(
last_error=error_message,
)
await db.commit()
return AssetCompletionResult(
result = AssetCompletionResult(
asset="storybook_images",
status=image_status,
value=story.image_url,
error=error_message,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="storybook_images_completed",
status="failed" if error_message else "succeeded",
message="Storybook image completion finished.",
metadata={
**_asset_result_metadata(result),
"completed_pages": sorted(completed_pages),
"failed_pages": sorted(failed_pages),
},
)
return result
async def _read_cached_audio_asset(story: Story, db: AsyncSession) -> bytes | None:
@@ -482,6 +764,7 @@ async def _complete_audio_asset(
db: AsyncSession,
*,
raise_on_failure: bool = True,
job=None,
) -> AssetCompletionResult:
"""Complete TTS audio generation through one asset workflow."""
@@ -490,32 +773,67 @@ async def _complete_audio_asset(
cached_audio = await _read_cached_audio_asset(story, db)
if cached_audio is not None:
return AssetCompletionResult(
result = AssetCompletionResult(
asset="audio",
status=StoryAssetStatus.READY,
value=cached_audio,
blocks_main_result=raise_on_failure,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="audio_cache_hit",
status="succeeded",
message="Cached story audio was reused.",
metadata=_asset_result_metadata(result),
)
return result
from app.services.provider_router import text_to_speech
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
await db.commit()
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="audio_started",
status="running",
message="Story audio generation started.",
metadata={"asset": "audio"},
)
try:
audio_data = await text_to_speech(story.story_text, db=db)
audio_data = await text_to_speech(
story.story_text,
db=db,
user_id=story.user_id,
generation_job=job,
story_id=story.id,
)
story.audio_path = write_story_audio_cache(story.id, audio_data)
sync_story_status(
story,
audio_status=StoryAssetStatus.READY,
)
await db.commit()
return AssetCompletionResult(
result = AssetCompletionResult(
asset="audio",
status=StoryAssetStatus.READY,
value=audio_data,
blocks_main_result=raise_on_failure,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="audio_succeeded",
status="succeeded",
message="Story audio was generated and cached.",
metadata=_asset_result_metadata(result),
)
return result
except Exception as exc:
provider_error = str(exc)
story.audio_path = None
@@ -527,18 +845,28 @@ async def _complete_audio_asset(
await db.commit()
logger.error("audio_generation_failed", story_id=story.id, error=provider_error)
result = AssetCompletionResult(
asset="audio",
status=StoryAssetStatus.FAILED,
error=provider_error,
blocks_main_result=raise_on_failure,
)
await _record_job_event_if_present(
db,
job=job,
story_id=story.id,
event_type="audio_failed",
status="failed",
message="Story audio generation failed.",
metadata=_asset_result_metadata(result),
)
if raise_on_failure:
raise HTTPException(
status_code=500,
detail=f"Audio generation failed: {provider_error}",
) from exc
return AssetCompletionResult(
asset="audio",
status=StoryAssetStatus.FAILED,
error=provider_error,
blocks_main_result=raise_on_failure,
)
return result
async def validate_profile_and_universe(
@@ -586,6 +914,8 @@ async def generate_and_save_story(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
*,
job=None,
) -> Story:
"""Generate generic story content and save to DB."""
profile_id, universe_id, memory_context = await _prepare_generation_context(
@@ -593,21 +923,32 @@ async def generate_and_save_story(
universe_id=request.universe_id,
user_id=user_id,
db=db,
job=job,
)
try:
result = await generate_story_content(
input_type=request.type,
data=request.data,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
user_id=user_id,
generation_job=job,
)
except Exception as exc:
raise HTTPException(
status_code=502,
detail="Story generation failed, please try again.",
) from exc
await _record_job_event_if_present(
db,
job=job,
event_type="narrative_generated",
status="succeeded",
message="Story narrative was generated.",
metadata={"mode": result.mode, "title": result.title},
)
return await _persist_text_story_result(
result=result,
@@ -615,6 +956,7 @@ async def generate_and_save_story(
profile_id=profile_id,
universe_id=universe_id,
db=db,
job=job,
)
@@ -622,9 +964,11 @@ async def generate_full_story_service(
request: GenerateRequest,
user_id: str,
db: AsyncSession,
*,
job=None,
) -> FullStoryResponse:
"""Generate story with parallel image generation."""
story = await generate_and_save_story(request, user_id, db)
story = await generate_and_save_story(request, user_id, db, job=job)
image_url: str | None = None
errors: dict[str, str | None] = {}
@@ -633,6 +977,7 @@ async def generate_full_story_service(
story,
db,
log_event="image_generation_failed",
job=job,
)
if image_result.succeeded and isinstance(image_result.value, str):
image_url = image_result.value
@@ -651,6 +996,7 @@ async def generate_full_story_service(
child_profile_id=story.child_profile_id,
universe_id=story.universe_id,
generation_status=story.generation_status,
text_status=story.text_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
@@ -662,6 +1008,8 @@ async def generate_storybook_service(
request: StorybookRequest,
user_id: str,
db: AsyncSession,
*,
job=None,
) -> StorybookResponse:
"""Generate storybook with parallel image generation for pages."""
profile_id, universe_id, memory_context = await _prepare_generation_context(
@@ -669,6 +1017,7 @@ async def generate_storybook_service(
universe_id=request.universe_id,
user_id=user_id,
db=db,
job=job,
)
logger.info(
@@ -684,13 +1033,27 @@ async def generate_storybook_service(
storybook = await generate_storybook(
keywords=request.keywords,
page_count=request.page_count,
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
)
except Exception as e:
education_theme=request.education_theme,
memory_context=memory_context,
db=db,
user_id=user_id,
generation_job=job,
)
except Exception as e:
logger.error("storybook_generation_failed", error=str(e))
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
await _record_job_event_if_present(
db,
job=job,
event_type="narrative_generated",
status="succeeded",
message="Storybook narrative and page plan were generated.",
metadata={
"mode": "storybook",
"title": storybook.title,
"page_count": len(storybook.pages),
},
)
final_cover_url = storybook.cover_url
cover_failed = False
@@ -701,7 +1064,12 @@ async def generate_storybook_service(
final_cover_url,
cover_failed,
failed_pages,
) = await _generate_storybook_image_assets(storybook, db)
) = await _generate_storybook_image_assets(
storybook,
db,
user_id=user_id,
job=job,
)
story, pages_data = await _persist_storybook_result(
storybook=storybook,
@@ -713,6 +1081,7 @@ async def generate_storybook_service(
cover_failed=cover_failed,
failed_pages=failed_pages,
db=db,
job=job,
)
response_pages = _storybook_pages_to_response(pages_data)
@@ -726,6 +1095,7 @@ async def generate_storybook_service(
cover_prompt=storybook.cover_prompt,
cover_url=final_cover_url,
generation_status=story.generation_status,
text_status=story.text_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
@@ -797,6 +1167,7 @@ async def _generate_generation_service_with_job(
),
user_id,
db,
job=job,
)
if storybook.id is None:
raise HTTPException(status_code=500, detail="Storybook generation did not persist.")
@@ -812,6 +1183,7 @@ async def _generate_generation_service_with_job(
)
return GenerationResponse(
id=storybook.id,
generation_job_id=job.id,
title=storybook.title,
mode="storybook",
pages=storybook.pages,
@@ -821,6 +1193,7 @@ async def _generate_generation_service_with_job(
main_character=storybook.main_character,
art_style=storybook.art_style,
generation_status=storybook.generation_status,
text_status=saved_story.text_status,
image_status=storybook.image_status,
audio_status=storybook.audio_status,
last_error=storybook.last_error,
@@ -838,7 +1211,7 @@ async def _generate_generation_service_with_job(
)
if request.generate_images:
story = await generate_full_story_service(generate_request, user_id, db)
story = await generate_full_story_service(generate_request, user_id, db, job=job)
saved_story = await get_story_detail(story.id, user_id, db)
await _record_postprocessing_event_if_needed(db, job=job, story=saved_story)
await finish_generation_job(
@@ -850,6 +1223,7 @@ async def _generate_generation_service_with_job(
)
return GenerationResponse(
id=story.id,
generation_job_id=job.id,
title=story.title,
mode=story.mode,
story_text=story.story_text,
@@ -859,6 +1233,7 @@ async def _generate_generation_service_with_job(
audio_ready=story.audio_ready,
errors=story.errors,
generation_status=story.generation_status,
text_status=saved_story.text_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
@@ -867,7 +1242,7 @@ async def _generate_generation_service_with_job(
retryable_assets=saved_story.retryable_assets,
)
story = await generate_and_save_story(generate_request, user_id, db)
story = await generate_and_save_story(generate_request, user_id, db, job=job)
await _record_postprocessing_event_if_needed(db, job=job, story=story)
await finish_generation_job(
db,
@@ -878,6 +1253,7 @@ async def _generate_generation_service_with_job(
)
return GenerationResponse(
id=story.id,
generation_job_id=job.id,
title=story.title,
mode=story.mode,
story_text=story.story_text,
@@ -885,6 +1261,7 @@ async def _generate_generation_service_with_job(
image_url=story.image_url,
cover_url=story.image_url,
generation_status=story.generation_status,
text_status=story.text_status,
image_status=story.image_status,
audio_status=story.audio_status,
last_error=story.last_error,
@@ -954,7 +1331,7 @@ async def create_story_from_result(
)
async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
async def _retry_cover_image_asset(story: Story, db: AsyncSession, *, job=None) -> None:
"""Retry cover generation for a text story."""
await _complete_cover_image_asset(
@@ -962,19 +1339,25 @@ async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
db,
last_error_prefix="封面生成失败",
log_event="cover_asset_retry_failed",
job=job,
)
async def _retry_storybook_image_assets(story: Story, db: AsyncSession) -> None:
async def _retry_storybook_image_assets(
story: Story,
db: AsyncSession,
*,
job=None,
) -> None:
"""Retry missing storybook cover/page images."""
await _complete_storybook_image_assets(story, db)
await _complete_storybook_image_assets(story, db, job=job)
async def _retry_audio_asset(story: Story, db: AsyncSession) -> None:
async def _retry_audio_asset(story: Story, db: AsyncSession, *, job=None) -> None:
"""Retry audio generation while preserving persisted status on provider failure."""
await _complete_audio_asset(story, db, raise_on_failure=False)
await _complete_audio_asset(story, db, raise_on_failure=False, job=job)
async def retry_story_assets(
@@ -1009,12 +1392,12 @@ async def retry_story_assets(
if "image" in requested_assets:
if story.mode == "storybook":
await _retry_storybook_image_assets(story, db)
await _retry_storybook_image_assets(story, db, job=job)
else:
await _retry_cover_image_asset(story, db)
await _retry_cover_image_asset(story, db, job=job)
if "audio" in requested_assets:
await _retry_audio_asset(story, db)
await _retry_audio_asset(story, db, job=job)
story = await get_story_detail(story_id, user_id, db)
await finish_generation_job(
@@ -1075,6 +1458,7 @@ async def generate_story_cover(
db,
raise_on_failure=True,
log_event="cover_generation_failed",
job=job,
)
story = await get_story_detail(story_id, user_id, db)
await finish_generation_job(
@@ -1121,7 +1505,12 @@ async def generate_story_audio(
try:
story = await get_story_detail(story_id, user_id, db)
audio_result = await _complete_audio_asset(story, db, raise_on_failure=True)
audio_result = await _complete_audio_asset(
story,
db,
raise_on_failure=True,
job=job,
)
story = await get_story_detail(story_id, user_id, db)
await finish_generation_job(
db,

View File

@@ -10,6 +10,7 @@ class StoryGenerationStatus(str, Enum):
"""Overall story generation lifecycle."""
NARRATIVE_READY = "narrative_ready"
PARTIAL_READY = "partial_ready"
ASSETS_GENERATING = "assets_generating"
COMPLETED = "completed"
DEGRADED_COMPLETED = "degraded_completed"
@@ -30,7 +31,10 @@ class StoryLike(Protocol):
story_text: str | None
pages: list[dict] | None
cover_prompt: str | None
image_url: str | None
generation_status: str
text_status: str
image_status: str
audio_status: str
last_error: str | None
@@ -55,6 +59,37 @@ def has_narrative_content(story: StoryLike) -> bool:
return bool(story.story_text) or bool(story.pages)
def _has_retryable_image(story: StoryLike, image_status: StoryAssetStatus) -> bool:
if image_status in {StoryAssetStatus.READY, StoryAssetStatus.GENERATING}:
return False
pages = story.pages or []
has_missing_page_image = any(
isinstance(page, dict)
and page.get("image_prompt")
and not page.get("image_url")
for page in pages
)
return bool(story.cover_prompt and not story.image_url) or has_missing_page_image
def _has_pending_assets(
story: StoryLike,
*,
image_status: StoryAssetStatus,
audio_status: StoryAssetStatus,
) -> bool:
"""Whether readable content still has optional assets to complete."""
if _has_retryable_image(story, image_status):
return True
return bool(story.story_text) and audio_status not in {
StoryAssetStatus.READY,
StoryAssetStatus.GENERATING,
}
def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
"""Derive the overall status from narrative and asset states."""
@@ -70,6 +105,9 @@ def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
if StoryAssetStatus.FAILED in (image_status, audio_status):
return StoryGenerationStatus.DEGRADED_COMPLETED
if _has_pending_assets(story, image_status=image_status, audio_status=audio_status):
return StoryGenerationStatus.PARTIAL_READY
if (
image_status == StoryAssetStatus.NOT_REQUESTED
and audio_status == StoryAssetStatus.NOT_REQUESTED
@@ -105,6 +143,12 @@ def sync_story_status(
if last_error is not _ERROR_UNSET:
story.last_error = last_error
story.text_status = (
StoryAssetStatus.READY.value
if has_narrative_content(story)
else StoryAssetStatus.FAILED.value
)
generation_status = resolve_story_generation_status(story)
story.generation_status = generation_status.value

View File

@@ -66,7 +66,8 @@ async def test_story(db_session: AsyncSession, test_user: User) -> Story:
story_text="从前有一只小兔子。",
cover_prompt="A cute rabbit in a forest",
mode="generated",
generation_status="narrative_ready",
generation_status="partial_ready",
text_status="ready",
image_status="not_requested",
audio_status="not_requested",
)
@@ -102,6 +103,7 @@ async def storybook_story(db_session: AsyncSession, test_user: User) -> Story:
image_url="https://example.com/storybook-cover.png",
mode="storybook",
generation_status="degraded_completed",
text_status="ready",
image_status="failed",
audio_status="not_requested",
last_error="第 2 页插图生成失败",
@@ -123,6 +125,7 @@ async def degraded_story_with_text(db_session: AsyncSession, test_user: User) ->
cover_prompt="A rabbit under the moon",
mode="generated",
generation_status="degraded_completed",
text_status="ready",
image_status="failed",
audio_status="not_requested",
last_error="封面生成失败",

View File

@@ -1,5 +1,7 @@
"""Generation job tracking tests."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
@@ -7,10 +9,37 @@ from sqlalchemy import select
from app.db.database import get_db
from app.db.models import GenerationJob, GenerationJobEvent
from app.main import app
from app.services.adapters import AdapterConfig
from app.services.adapters.storybook.primary import Storybook, StorybookPage
from app.services.adapters.text.models import StoryOutput
from app.services.generation_jobs import create_generation_job, record_generation_event
pytestmark = pytest.mark.asyncio
def build_storybook_output() -> Storybook:
"""Create a reusable mocked storybook payload."""
return Storybook(
title="森林里的发光冒险",
main_character="小兔子露露",
art_style="温暖水彩",
cover_prompt="A glowing forest storybook cover",
pages=[
StorybookPage(
page_number=1,
text="露露第一次走进会发光的森林。",
image_prompt="Lulu entering a glowing forest",
),
StorybookPage(
page_number=2,
text="她遇到了一只会唱歌的萤火虫。",
image_prompt="Lulu meeting a singing firefly",
),
],
)
async def test_unified_generation_records_job_events_and_retryable_assets(
db_session,
test_user,
@@ -39,8 +68,9 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "narrative_ready"
assert data["generation_status"] == "partial_ready"
assert data["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"]
jobs = (
await db_session.execute(
@@ -55,6 +85,7 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
assert job.status == "completed"
assert job.current_step == "generation_completed"
assert job.result_snapshot["retryable_assets"] == ["image", "audio"]
assert data["generation_job_id"] == job.id
events = (
await db_session.execute(
@@ -65,8 +96,37 @@ async def test_unified_generation_records_job_events_and_retryable_assets(
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
assert events[1].event_metadata["has_memory_context"] is False
assert events[2].event_metadata["title"] == "小兔子的冒险"
assert events[3].story_id == data["id"]
detail_response = await client.get(f"/api/generations/jobs/{job.id}")
assert detail_response.status_code == 200
detail = detail_response.json()
assert detail["id"] == job.id
assert detail["story_id"] == data["id"]
assert detail["progress_percent"] == 100
assert detail["progress_label"] == "已完成"
assert detail["is_terminal"] is True
assert [event["event_type"] for event in detail["events"]] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"story_saved",
"generation_completed",
]
list_response = await client.get(f"/api/generations/{data['id']}/jobs")
assert list_response.status_code == 200
job_list = list_response.json()
assert [item["id"] for item in job_list] == [job.id]
assert job_list[0]["progress_percent"] == 100
assert job_list[0]["is_terminal"] is True
finally:
app.dependency_overrides.clear()
@@ -122,7 +182,252 @@ async def test_asset_retry_records_job_events_and_updates_retryable_assets(
assert [event.event_type for event in events] == [
"request_accepted",
"asset_retry_started",
"cover_image_started",
"cover_image_succeeded",
"asset_retry_completed",
]
assert events[3].event_metadata["asset"] == "cover_image"
finally:
app.dependency_overrides.clear()
async def test_storybook_generation_records_page_image_events(
db_session,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
try:
with patch(
"app.services.story_service.generate_storybook",
new_callable=AsyncMock,
) as mock_storybook:
with patch(
"app.services.story_service.generate_image",
new_callable=AsyncMock,
) as mock_image:
mock_storybook.return_value = build_storybook_output()
mock_image.side_effect = [
"https://example.com/storybook-cover.png",
"https://example.com/storybook-page-1.png",
"https://example.com/storybook-page-2.png",
]
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post(
"/api/generations",
json={
"output_mode": "storybook",
"type": "keywords",
"data": "森林, 发光, 友情",
"page_count": 6,
"generate_images": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["mode"] == "storybook"
assert data["image_status"] == "ready"
job = (
await db_session.execute(
select(GenerationJob).where(
GenerationJob.story_id == data["id"],
GenerationJob.output_mode == "storybook",
)
)
).scalar_one()
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"context_prepared",
"narrative_generated",
"storybook_images_started",
"storybook_cover_image_succeeded",
"storybook_page_image_succeeded",
"storybook_page_image_succeeded",
"storybook_images_completed",
"story_saved",
"generation_completed",
]
page_events = [
event
for event in events
if event.event_type == "storybook_page_image_succeeded"
]
assert [event.event_metadata["page_number"] for event in page_events] == [1, 2]
assert events[7].event_metadata["completed_pages"] == [1, 2]
finally:
app.dependency_overrides.clear()
async def test_provider_call_events_record_latency_and_cost(
db_session,
test_user,
):
from app.services import provider_router
mock_result = StoryOutput(
mode="generated",
title="带供应商轨迹的故事",
story_text="一只小鹿学会了复盘。",
cover_prompt_suggestion="A deer with a golden bookmark",
)
class MockAdapter:
estimated_cost = 0.0123
def __init__(self, config):
self.config = config
async def execute(self, **kwargs):
return mock_result
job = await create_generation_job(
db_session,
user_id=test_user.id,
output_mode="story",
input_type="keywords",
request_payload={"data": "小鹿"},
)
with patch.object(
provider_router,
"_get_providers_with_config",
new_callable=AsyncMock,
) as mock_providers:
mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)]
with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
result = await provider_router.generate_story_content(
input_type="keywords",
data="小鹿",
db=db_session,
generation_job=job,
)
assert result == mock_result
events = (
await db_session.execute(
select(GenerationJobEvent)
.where(GenerationJobEvent.job_id == job.id)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
assert [event.event_type for event in events] == [
"request_accepted",
"provider_call_started",
"provider_call_succeeded",
]
provider_event = events[2]
assert provider_event.event_metadata["capability"] == "text"
assert provider_event.event_metadata["adapter"] == "demo"
assert provider_event.event_metadata["strategy"] == "priority"
assert provider_event.event_metadata["latency_ms"] >= 0
assert provider_event.event_metadata["estimated_cost_usd"] == 0.0123
async def test_story_provider_stats_aggregate_job_events(
db_session,
auth_token,
degraded_story_with_text,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
job = await create_generation_job(
db_session,
user_id=degraded_story_with_text.user_id,
output_mode="asset_retry",
input_type="image",
request_payload={"assets": ["image"]},
story_id=degraded_story_with_text.id,
)
await record_generation_event(
db_session,
job=job,
story_id=degraded_story_with_text.id,
event_type="provider_call_succeeded",
status="succeeded",
metadata={
"capability": "image",
"adapter": "demo",
"strategy": "priority",
"latency_ms": 42,
"estimated_cost_usd": 0.01,
},
)
await record_generation_event(
db_session,
job=job,
story_id=degraded_story_with_text.id,
event_type="provider_call_failed",
status="failed",
metadata={
"capability": "image",
"adapter": "cqtai",
"strategy": "priority",
"latency_ms": 120,
"estimated_cost_usd": 0.02,
"error": "timeout",
},
)
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.get(
f"/api/generations/{degraded_story_with_text.id}/provider-stats"
)
assert response.status_code == 200
data = response.json()
assert data["story_id"] == degraded_story_with_text.id
assert data["total_calls"] == 2
assert data["successful_calls"] == 1
assert data["failed_calls"] == 1
assert data["avg_latency_ms"] == 81.0
assert data["estimated_cost_usd"] == 0.01
assert data["by_provider"] == [
{
"capability": "image",
"adapter": "cqtai",
"call_count": 1,
"success_count": 0,
"failure_count": 1,
"avg_latency_ms": 120.0,
"estimated_cost_usd": 0.0,
},
{
"capability": "image",
"adapter": "demo",
"call_count": 1,
"success_count": 1,
"failure_count": 0,
"avg_latency_ms": 42.0,
"estimated_cost_usd": 0.01,
},
]
finally:
app.dependency_overrides.clear()

View File

@@ -76,7 +76,8 @@ class TestStoryGenerate:
assert "title" in data
assert "story_text" in data
assert data["mode"] == "generated"
assert data["generation_status"] == "narrative_ready"
assert data["generation_status"] == "partial_ready"
assert data["text_status"] == "ready"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
@@ -101,7 +102,8 @@ class TestStoryList:
assert len(data) == 1
assert data[0]["id"] == test_story.id
assert data[0]["title"] == test_story.title
assert data[0]["generation_status"] == "narrative_ready"
assert data[0]["generation_status"] == "partial_ready"
assert data[0]["text_status"] == "ready"
assert data[0]["image_status"] == "not_requested"
assert data[0]["audio_status"] == "not_requested"
@@ -133,7 +135,8 @@ class TestStoryDetail:
assert data["id"] == test_story.id
assert data["title"] == test_story.title
assert data["story_text"] == test_story.story_text
assert data["generation_status"] == "narrative_ready"
assert data["generation_status"] == "partial_ready"
assert data["text_status"] == "ready"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
@@ -250,7 +253,7 @@ class TestAudio:
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
detail = detail_response.json()
assert detail["audio_status"] == "ready"
assert detail["generation_status"] == "completed"
assert detail["generation_status"] == "partial_ready"
assert detail["last_error"] is None
def test_get_audio_regenerates_when_cache_file_is_missing(
@@ -335,7 +338,7 @@ class TestGenerateFull:
assert data["image_url"] == "https://example.com/image.png"
assert data["audio_ready"] is False
assert data["errors"] == {}
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
@@ -412,7 +415,7 @@ class TestUnifiedGenerations:
assert data["image_url"] == "https://example.com/image.png"
assert data["cover_url"] == "https://example.com/image.png"
assert data["pages"] is None
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["errors"] == {}
@@ -436,7 +439,7 @@ class TestUnifiedGenerations:
data = response.json()
assert data["mode"] == "generated"
assert data["image_url"] is None
assert data["generation_status"] == "narrative_ready"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "not_requested"
def test_create_story_generation_image_failure(
@@ -530,7 +533,7 @@ class TestUnifiedGenerations:
assert response.status_code == 200
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
@@ -551,7 +554,7 @@ class TestImageGenerateSuccess:
)
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
@@ -578,7 +581,7 @@ class TestAssetRetry:
)
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "ready"
assert data["audio_status"] == "not_requested"
assert data["last_error"] is None
@@ -629,7 +632,7 @@ class TestAssetRetry:
assert response.status_code == 200
data = response.json()
assert data["generation_status"] == "completed"
assert data["generation_status"] == "partial_ready"
assert data["image_status"] == "not_requested"
assert data["audio_status"] == "ready"
assert data["last_error"] is None

View File

@@ -0,0 +1,109 @@
"""Tests for derived story generation statuses."""
from types import SimpleNamespace
from app.services.story_status import (
StoryAssetStatus,
StoryGenerationStatus,
resolve_story_generation_status,
sync_story_status,
)
def make_story(**overrides):
data = {
"story_text": "Once upon a time.",
"pages": None,
"cover_prompt": "A warm forest cover",
"image_url": None,
"generation_status": "narrative_ready",
"text_status": "ready",
"image_status": "not_requested",
"audio_status": "not_requested",
"last_error": None,
}
data.update(overrides)
return SimpleNamespace(**data)
def test_text_story_without_assets_is_partial_ready():
story = make_story()
sync_story_status(story)
assert story.text_status == "ready"
assert story.generation_status == StoryGenerationStatus.PARTIAL_READY.value
def test_text_story_with_all_assets_is_completed():
story = make_story(
image_url="https://example.com/cover.png",
image_status="ready",
audio_status="ready",
)
assert resolve_story_generation_status(story) == StoryGenerationStatus.COMPLETED
def test_failed_asset_keeps_readable_story_degraded():
story = make_story(image_status="failed", last_error="cover failed")
sync_story_status(story)
assert story.text_status == "ready"
assert story.generation_status == StoryGenerationStatus.DEGRADED_COMPLETED.value
assert story.last_error == "cover failed"
def test_storybook_missing_page_image_is_partial_ready():
story = make_story(
story_text=None,
pages=[
{
"page_number": 1,
"text": "Page one",
"image_prompt": "Page one image",
"image_url": "https://example.com/page-1.png",
},
{
"page_number": 2,
"text": "Page two",
"image_prompt": "Page two image",
"image_url": None,
},
],
cover_prompt="Storybook cover",
image_url="https://example.com/cover.png",
image_status="not_requested",
)
assert resolve_story_generation_status(story) == StoryGenerationStatus.PARTIAL_READY
def test_storybook_with_all_images_is_completed():
story = make_story(
story_text=None,
pages=[
{
"page_number": 1,
"text": "Page one",
"image_prompt": "Page one image",
"image_url": "https://example.com/page-1.png",
},
],
cover_prompt="Storybook cover",
image_url="https://example.com/cover.png",
image_status="ready",
audio_status="not_requested",
)
assert resolve_story_generation_status(story) == StoryGenerationStatus.COMPLETED
def test_missing_narrative_sets_text_failed():
story = make_story(story_text=None, pages=None)
sync_story_status(story, image_status=StoryAssetStatus.NOT_REQUESTED)
assert story.text_status == "failed"
assert story.generation_status == StoryGenerationStatus.FAILED.value