feat: add generation trace and partial-ready workflow status
This commit is contained in:
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
@@ -17,6 +19,7 @@ def _story_snapshot(story: Story | None) -> dict[str, Any]:
|
||||
"story_id": story.id,
|
||||
"mode": story.mode,
|
||||
"generation_status": story.generation_status,
|
||||
"text_status": story.text_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
"retryable_assets": story.retryable_assets,
|
||||
@@ -32,6 +35,48 @@ def _job_status_from_story(story: Story) -> str:
|
||||
return "completed"
|
||||
|
||||
|
||||
def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
||||
"""Resolve a compact progress summary for polling-oriented clients."""
|
||||
|
||||
if job.status == "failed":
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
"progress_label": "生成失败",
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
if job.status in {"completed", "degraded_completed"}:
|
||||
return {
|
||||
"progress_percent": 100,
|
||||
"progress_label": "已完成" if job.status == "completed" else "降级完成",
|
||||
"is_terminal": True,
|
||||
}
|
||||
|
||||
progress_map: dict[str, tuple[int, str]] = {
|
||||
"request_accepted": (5, "已接收请求"),
|
||||
"context_prepared": (20, "上下文已准备"),
|
||||
"narrative_generated": (45, "正文已生成"),
|
||||
"story_saved": (60, "主记录已保存"),
|
||||
"provider_call_started": (65, "Provider 调用中"),
|
||||
"provider_call_succeeded": (72, "Provider 调用成功"),
|
||||
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
|
||||
"cover_image_started": (75, "封面生成中"),
|
||||
"storybook_images_started": (75, "绘本插图生成中"),
|
||||
"audio_started": (75, "音频生成中"),
|
||||
"asset_retry_started": (25, "资源重试中"),
|
||||
"postprocessing_queued": (90, "后处理已排队"),
|
||||
"asset_generation_completed": (100, "资源已完成"),
|
||||
"asset_retry_completed": (100, "资源重试完成"),
|
||||
"generation_completed": (100, "生成完成"),
|
||||
}
|
||||
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
|
||||
return {
|
||||
"progress_percent": percent,
|
||||
"progress_label": label,
|
||||
"is_terminal": percent >= 100,
|
||||
}
|
||||
|
||||
|
||||
async def create_generation_job(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -131,3 +176,198 @@ async def finish_generation_job(
|
||||
await db.commit()
|
||||
await db.refresh(job)
|
||||
return job
|
||||
|
||||
|
||||
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
|
||||
"""Convert a generation event ORM object to an API response dict."""
|
||||
|
||||
return {
|
||||
"id": event.id,
|
||||
"job_id": event.job_id,
|
||||
"story_id": event.story_id,
|
||||
"event_type": event.event_type,
|
||||
"status": event.status,
|
||||
"message": event.message,
|
||||
"event_metadata": event.event_metadata or {},
|
||||
"created_at": event.created_at,
|
||||
}
|
||||
|
||||
|
||||
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
|
||||
"""Convert a generation job ORM object to an API summary dict."""
|
||||
|
||||
progress = _job_progress(job)
|
||||
return {
|
||||
"id": job.id,
|
||||
"story_id": job.story_id,
|
||||
"output_mode": job.output_mode,
|
||||
"input_type": job.input_type,
|
||||
"status": job.status,
|
||||
"current_step": job.current_step,
|
||||
**progress,
|
||||
"result_snapshot": job.result_snapshot or {},
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at,
|
||||
"updated_at": job.updated_at,
|
||||
}
|
||||
|
||||
|
||||
async def get_generation_job_detail(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job_id: str,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Return a user-owned generation job with its ordered event stream."""
|
||||
|
||||
result = await db.execute(
|
||||
select(GenerationJob).where(
|
||||
GenerationJob.id == job_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
)
|
||||
)
|
||||
job = result.scalar_one_or_none()
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Generation job not found")
|
||||
|
||||
events = (
|
||||
await db.execute(
|
||||
select(GenerationJobEvent)
|
||||
.where(GenerationJobEvent.job_id == job.id)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
return {
|
||||
**generation_job_to_summary(job),
|
||||
"request_payload": job.request_payload or {},
|
||||
"events": [generation_event_to_response(event) for event in events],
|
||||
}
|
||||
|
||||
|
||||
async def list_story_generation_jobs(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return recent generation jobs for a user-owned story."""
|
||||
|
||||
jobs = (
|
||||
await db.execute(
|
||||
select(GenerationJob)
|
||||
.where(
|
||||
GenerationJob.story_id == story_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
)
|
||||
.order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
|
||||
)
|
||||
).scalars().all()
|
||||
return [generation_job_to_summary(job) for job in jobs]
|
||||
|
||||
|
||||
def _as_float(value: Any) -> float | None:
|
||||
if isinstance(value, int | float):
|
||||
return float(value)
|
||||
return None
|
||||
|
||||
|
||||
async def get_story_provider_stats(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
|
||||
|
||||
events = (
|
||||
await db.execute(
|
||||
select(GenerationJobEvent)
|
||||
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
|
||||
.where(
|
||||
GenerationJob.story_id == story_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
GenerationJobEvent.event_type.in_(
|
||||
["provider_call_succeeded", "provider_call_failed"]
|
||||
),
|
||||
)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
by_key: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
total_latency = 0.0
|
||||
latency_count = 0
|
||||
total_cost = 0.0
|
||||
successful_calls = 0
|
||||
failed_calls = 0
|
||||
|
||||
for event in events:
|
||||
metadata = event.event_metadata or {}
|
||||
capability = str(metadata.get("capability") or "unknown")
|
||||
adapter = str(metadata.get("adapter") or "unknown")
|
||||
key = (capability, adapter)
|
||||
bucket = by_key.setdefault(
|
||||
key,
|
||||
{
|
||||
"capability": capability,
|
||||
"adapter": adapter,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"latency_total": 0.0,
|
||||
"latency_count": 0,
|
||||
"estimated_cost_usd": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
bucket["call_count"] += 1
|
||||
latency = _as_float(metadata.get("latency_ms"))
|
||||
if latency is not None:
|
||||
bucket["latency_total"] += latency
|
||||
bucket["latency_count"] += 1
|
||||
total_latency += latency
|
||||
latency_count += 1
|
||||
|
||||
if event.event_type == "provider_call_succeeded":
|
||||
bucket["success_count"] += 1
|
||||
successful_calls += 1
|
||||
cost = _as_float(metadata.get("estimated_cost_usd")) or 0.0
|
||||
bucket["estimated_cost_usd"] += cost
|
||||
total_cost += cost
|
||||
else:
|
||||
bucket["failure_count"] += 1
|
||||
failed_calls += 1
|
||||
|
||||
by_provider = []
|
||||
for bucket in by_key.values():
|
||||
bucket_latency_count = bucket.pop("latency_count")
|
||||
bucket_latency_total = bucket.pop("latency_total")
|
||||
by_provider.append(
|
||||
{
|
||||
**bucket,
|
||||
"avg_latency_ms": (
|
||||
round(bucket_latency_total / bucket_latency_count, 2)
|
||||
if bucket_latency_count
|
||||
else None
|
||||
),
|
||||
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
|
||||
}
|
||||
)
|
||||
|
||||
by_provider.sort(
|
||||
key=lambda item: (
|
||||
str(item["capability"]),
|
||||
str(item["adapter"]),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"story_id": story_id,
|
||||
"total_calls": successful_calls + failed_calls,
|
||||
"successful_calls": successful_calls,
|
||||
"failed_calls": failed_calls,
|
||||
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
|
||||
"estimated_cost_usd": round(total_cost, 6),
|
||||
"by_provider": by_provider,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.core.logging import get_logger
|
||||
from app.services.adapters import AdapterConfig, AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.generation_jobs import record_generation_event
|
||||
from app.services.provider_cache import get_providers
|
||||
from app.services.provider_metrics import health_checker, metrics_collector
|
||||
from app.services.provider_policy import (
|
||||
@@ -22,6 +23,7 @@ from app.services.provider_policy import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.admin_models import Provider
|
||||
from app.db.models import GenerationJob
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -36,6 +38,58 @@ _round_robin_counters: dict[ProviderType, int] = {
|
||||
_latency_cache: dict[str, float] = {}
|
||||
|
||||
|
||||
def _safe_estimated_cost(adapter) -> float:
|
||||
"""Return an adapter cost value that is safe to serialize in job events."""
|
||||
|
||||
try:
|
||||
return float(adapter.estimated_cost)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
async def _record_provider_event_if_present(
|
||||
db: AsyncSession | None,
|
||||
*,
|
||||
job: "GenerationJob | None",
|
||||
event_type: str,
|
||||
status: str,
|
||||
provider_type: ProviderType,
|
||||
adapter_name: str,
|
||||
strategy: RoutingStrategy,
|
||||
provider_id: str | None = None,
|
||||
story_id: int | None = None,
|
||||
latency_ms: int | None = None,
|
||||
estimated_cost: float | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Append provider call telemetry to the active generation job."""
|
||||
|
||||
if db is None or job is None:
|
||||
return
|
||||
|
||||
await record_generation_event(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story_id,
|
||||
event_type=event_type,
|
||||
status=status,
|
||||
message=(
|
||||
f"{provider_type} provider {adapter_name} {status}."
|
||||
if error is None
|
||||
else f"{provider_type} provider {adapter_name} failed."
|
||||
),
|
||||
metadata={
|
||||
"capability": provider_type,
|
||||
"adapter": adapter_name,
|
||||
"provider_id": provider_id,
|
||||
"strategy": strategy.value,
|
||||
"latency_ms": latency_ms,
|
||||
"estimated_cost_usd": estimated_cost,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
|
||||
"""根据 config_ref 或适配器名称获取 API Key。"""
|
||||
# 优先使用 config_ref
|
||||
@@ -228,6 +282,8 @@ async def _route_with_failover(
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
generation_job: "GenerationJob | None" = None,
|
||||
story_id: int | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""通用 provider failover 路由。
|
||||
@@ -237,6 +293,8 @@ async def _route_with_failover(
|
||||
strategy: 路由策略
|
||||
db: 数据库会话(可选,用于指标收集和熔断检查)
|
||||
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
||||
generation_job: 生成任务(可选,用于记录 provider 调用轨迹)
|
||||
story_id: 故事 ID(可选,用于关联 provider 事件)
|
||||
**kwargs: 传递给适配器的参数
|
||||
"""
|
||||
providers = await _get_providers_with_config(provider_type)
|
||||
@@ -274,7 +332,9 @@ async def _route_with_failover(
|
||||
errors.append(f"{name}: 适配器未注册")
|
||||
continue
|
||||
|
||||
provider_id = db_provider.id if db_provider else None
|
||||
provider_id = str(db_provider.id) if db_provider else None
|
||||
estimated_cost: float | None = None
|
||||
start_time: float | None = None
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
@@ -285,6 +345,20 @@ async def _route_with_failover(
|
||||
)
|
||||
|
||||
adapter = adapter_class(config)
|
||||
estimated_cost = _safe_estimated_cost(adapter)
|
||||
|
||||
await _record_provider_event_if_present(
|
||||
db,
|
||||
job=generation_job,
|
||||
story_id=story_id,
|
||||
event_type="provider_call_started",
|
||||
status="running",
|
||||
provider_type=provider_type,
|
||||
adapter_name=name,
|
||||
provider_id=provider_id,
|
||||
strategy=strategy,
|
||||
estimated_cost=estimated_cost,
|
||||
)
|
||||
|
||||
# 执行并计时
|
||||
start_time = time.time()
|
||||
@@ -301,7 +375,7 @@ async def _route_with_failover(
|
||||
provider_id=provider_id,
|
||||
success=True,
|
||||
latency_ms=latency_ms,
|
||||
cost_usd=adapter.estimated_cost,
|
||||
cost_usd=estimated_cost,
|
||||
)
|
||||
await health_checker.record_call_result(db, provider_id, success=True)
|
||||
|
||||
@@ -312,10 +386,24 @@ async def _route_with_failover(
|
||||
user_id=user_id,
|
||||
provider_name=name,
|
||||
capability=provider_type,
|
||||
estimated_cost=adapter.estimated_cost,
|
||||
estimated_cost=estimated_cost,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
await _record_provider_event_if_present(
|
||||
db,
|
||||
job=generation_job,
|
||||
story_id=story_id,
|
||||
event_type="provider_call_succeeded",
|
||||
status="succeeded",
|
||||
provider_type=provider_type,
|
||||
adapter_name=name,
|
||||
provider_id=provider_id,
|
||||
strategy=strategy,
|
||||
latency_ms=latency_ms,
|
||||
estimated_cost=estimated_cost,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"provider_success",
|
||||
provider_type=provider_type,
|
||||
@@ -326,6 +414,11 @@ async def _route_with_failover(
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
latency_ms = (
|
||||
int((time.time() - start_time) * 1000)
|
||||
if start_time is not None
|
||||
else None
|
||||
)
|
||||
logger.warning(
|
||||
"provider_failed",
|
||||
provider_type=provider_type,
|
||||
@@ -346,6 +439,21 @@ async def _route_with_failover(
|
||||
db, provider_id, success=False, error=error_msg
|
||||
)
|
||||
|
||||
await _record_provider_event_if_present(
|
||||
db,
|
||||
job=generation_job,
|
||||
story_id=story_id,
|
||||
event_type="provider_call_failed",
|
||||
status="failed",
|
||||
provider_type=provider_type,
|
||||
adapter_name=name,
|
||||
provider_id=provider_id,
|
||||
strategy=strategy,
|
||||
latency_ms=latency_ms,
|
||||
estimated_cost=estimated_cost,
|
||||
error=error_msg,
|
||||
)
|
||||
|
||||
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||||
|
||||
|
||||
@@ -356,12 +464,16 @@ async def generate_story_content(
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
generation_job: "GenerationJob | None" = None,
|
||||
) -> StoryOutput:
|
||||
"""生成或润色故事,支持 failover。"""
|
||||
return await _route_with_failover(
|
||||
"text",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=generation_job,
|
||||
input_type=input_type,
|
||||
data=data,
|
||||
education_theme=education_theme,
|
||||
@@ -373,19 +485,42 @@ async def generate_image(
|
||||
prompt: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
generation_job: "GenerationJob | None" = None,
|
||||
story_id: int | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""生成图片,返回 URL,支持 failover。"""
|
||||
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
|
||||
return await _route_with_failover(
|
||||
"image",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=generation_job,
|
||||
story_id=story_id,
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def text_to_speech(
|
||||
text: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
generation_job: "GenerationJob | None" = None,
|
||||
story_id: int | None = None,
|
||||
) -> bytes:
|
||||
"""文本转语音,返回 MP3 bytes,支持 failover。"""
|
||||
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
|
||||
return await _route_with_failover(
|
||||
"tts",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=generation_job,
|
||||
story_id=story_id,
|
||||
text=text,
|
||||
)
|
||||
|
||||
|
||||
async def generate_storybook(
|
||||
@@ -395,6 +530,8 @@ async def generate_storybook(
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
generation_job: "GenerationJob | None" = None,
|
||||
):
|
||||
"""生成分页故事书,支持 failover。"""
|
||||
from app.services.adapters.storybook.primary import Storybook
|
||||
@@ -403,6 +540,8 @@ async def generate_storybook(
|
||||
"storybook",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=generation_job,
|
||||
keywords=keywords,
|
||||
page_count=page_count,
|
||||
education_theme=education_theme,
|
||||
|
||||
@@ -67,6 +67,43 @@ class AssetCompletionResult:
|
||||
return self.status == StoryAssetStatus.READY and self.error is None
|
||||
|
||||
|
||||
async def _record_job_event_if_present(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job,
|
||||
event_type: str,
|
||||
status: str,
|
||||
story_id: int | None = None,
|
||||
message: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Append a workflow event when the caller is running under a tracked job."""
|
||||
|
||||
if job is None:
|
||||
return
|
||||
|
||||
await record_generation_event(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story_id,
|
||||
event_type=event_type,
|
||||
status=status,
|
||||
message=message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def _asset_result_metadata(result: AssetCompletionResult) -> dict:
|
||||
"""Build JSON-safe metadata for asset workflow events."""
|
||||
|
||||
return {
|
||||
"asset": result.asset,
|
||||
"status": result.status.value,
|
||||
"error": result.error,
|
||||
"blocks_main_result": result.blocks_main_result,
|
||||
}
|
||||
|
||||
|
||||
def _build_storybook_error_message(
|
||||
*,
|
||||
cover_failed: bool,
|
||||
@@ -125,6 +162,7 @@ async def _prepare_generation_context(
|
||||
universe_id: str | None,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
job=None,
|
||||
) -> tuple[str | None, str | None, str]:
|
||||
"""Validate ownership and build the shared generation context."""
|
||||
|
||||
@@ -136,6 +174,18 @@ async def _prepare_generation_context(
|
||||
resolved_universe_id,
|
||||
db,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="context_prepared",
|
||||
status="succeeded",
|
||||
message="Profile, universe, and memory context were prepared.",
|
||||
metadata={
|
||||
"profile_id": resolved_profile_id,
|
||||
"universe_id": resolved_universe_id,
|
||||
"has_memory_context": bool(memory_context),
|
||||
},
|
||||
)
|
||||
return resolved_profile_id, resolved_universe_id, memory_context
|
||||
|
||||
|
||||
@@ -173,6 +223,7 @@ async def _persist_text_story_result(
|
||||
profile_id: str | None,
|
||||
universe_id: str | None,
|
||||
db: AsyncSession,
|
||||
job=None,
|
||||
) -> Story:
|
||||
"""Persist generated text content as the unified story record."""
|
||||
|
||||
@@ -195,6 +246,20 @@ async def _persist_text_story_result(
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
_trigger_story_postprocessing(story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="story_saved",
|
||||
status="succeeded",
|
||||
message="Readable story record was saved.",
|
||||
metadata={
|
||||
"mode": story.mode,
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
},
|
||||
)
|
||||
return story
|
||||
|
||||
|
||||
@@ -229,21 +294,47 @@ def _storybook_pages_to_response(pages_data: list[dict]) -> list[StorybookPageRe
|
||||
async def _generate_storybook_image_assets(
|
||||
storybook: Storybook,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
job=None,
|
||||
) -> tuple[str | None, bool, list[int]]:
|
||||
"""Generate storybook cover and page images before persistence."""
|
||||
|
||||
final_cover_url = storybook.cover_url
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
completed_pages: list[int] = []
|
||||
attempted_cover = bool(storybook.cover_prompt and not storybook.cover_url)
|
||||
attempted_pages = [
|
||||
page.page_number
|
||||
for page in storybook.pages
|
||||
if page.image_prompt and not page.image_url
|
||||
]
|
||||
|
||||
logger.info("storybook_parallel_generation_start", page_count=len(storybook.pages))
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="storybook_images_started",
|
||||
status="running",
|
||||
message="Storybook cover and page image generation started.",
|
||||
metadata={
|
||||
"attempted_cover": attempted_cover,
|
||||
"attempted_pages": attempted_pages,
|
||||
},
|
||||
)
|
||||
|
||||
async def _gen_cover() -> str | None:
|
||||
nonlocal cover_failed
|
||||
|
||||
if storybook.cover_prompt and not storybook.cover_url:
|
||||
try:
|
||||
return await generate_image(storybook.cover_prompt, db=db)
|
||||
return await generate_image(
|
||||
storybook.cover_prompt,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=job,
|
||||
)
|
||||
except Exception as exc:
|
||||
cover_failed = True
|
||||
logger.warning("cover_gen_failed", error=str(exc))
|
||||
@@ -254,7 +345,13 @@ async def _generate_storybook_image_assets(
|
||||
return
|
||||
|
||||
try:
|
||||
page.image_url = await generate_image(page.image_prompt, db=db)
|
||||
page.image_url = await generate_image(
|
||||
page.image_prompt,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=job,
|
||||
)
|
||||
completed_pages.append(page.page_number)
|
||||
except Exception as exc:
|
||||
failed_pages.append(page.page_number)
|
||||
logger.warning("page_gen_failed", page=page.page_number, error=str(exc))
|
||||
@@ -270,6 +367,57 @@ async def _generate_storybook_image_assets(
|
||||
final_cover_url = cover_result
|
||||
|
||||
logger.info("storybook_parallel_generation_complete")
|
||||
if attempted_cover:
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type=(
|
||||
"storybook_cover_image_failed"
|
||||
if cover_failed
|
||||
else "storybook_cover_image_succeeded"
|
||||
),
|
||||
status="failed" if cover_failed else "succeeded",
|
||||
message=(
|
||||
"Storybook cover image generation failed."
|
||||
if cover_failed
|
||||
else "Storybook cover image was generated."
|
||||
),
|
||||
metadata={"asset": "image", "scope": "cover"},
|
||||
)
|
||||
|
||||
for page_number in sorted(completed_pages):
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="storybook_page_image_succeeded",
|
||||
status="succeeded",
|
||||
message="Storybook page image was generated.",
|
||||
metadata={"asset": "image", "scope": "page", "page_number": page_number},
|
||||
)
|
||||
|
||||
for page_number in sorted(failed_pages):
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="storybook_page_image_failed",
|
||||
status="failed",
|
||||
message="Storybook page image generation failed.",
|
||||
metadata={"asset": "image", "scope": "page", "page_number": page_number},
|
||||
)
|
||||
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="storybook_images_completed",
|
||||
status="failed" if cover_failed or failed_pages else "succeeded",
|
||||
message="Storybook image generation finished.",
|
||||
metadata={
|
||||
"asset": "image",
|
||||
"attempted_cover": attempted_cover,
|
||||
"completed_pages": sorted(completed_pages),
|
||||
"failed_pages": sorted(failed_pages),
|
||||
},
|
||||
)
|
||||
return final_cover_url, cover_failed, failed_pages
|
||||
|
||||
|
||||
@@ -284,6 +432,7 @@ async def _persist_storybook_result(
|
||||
cover_failed: bool,
|
||||
failed_pages: list[int],
|
||||
db: AsyncSession,
|
||||
job=None,
|
||||
) -> tuple[Story, list[dict]]:
|
||||
"""Persist generated storybook content as the unified story record."""
|
||||
|
||||
@@ -317,6 +466,21 @@ async def _persist_storybook_result(
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
_trigger_story_postprocessing(story)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="story_saved",
|
||||
status="succeeded",
|
||||
message="Storybook record was saved.",
|
||||
metadata={
|
||||
"mode": story.mode,
|
||||
"page_count": len(pages_data),
|
||||
"generation_status": story.generation_status,
|
||||
"image_status": story.image_status,
|
||||
"audio_status": story.audio_status,
|
||||
},
|
||||
)
|
||||
return story, pages_data
|
||||
|
||||
|
||||
@@ -327,6 +491,7 @@ async def _complete_cover_image_asset(
|
||||
raise_on_failure: bool = False,
|
||||
last_error_prefix: str | None = None,
|
||||
log_event: str = "cover_asset_generation_failed",
|
||||
job=None,
|
||||
) -> AssetCompletionResult:
|
||||
"""Generate or retry a text story cover through one asset workflow."""
|
||||
|
||||
@@ -335,18 +500,43 @@ async def _complete_cover_image_asset(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="cover_image_started",
|
||||
status="running",
|
||||
message="Cover image generation started.",
|
||||
metadata={"asset": "image", "cover_prompt_present": True},
|
||||
)
|
||||
|
||||
try:
|
||||
image_url = await generate_image(story.cover_prompt, db=db)
|
||||
image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
db=db,
|
||||
user_id=story.user_id,
|
||||
generation_job=job,
|
||||
story_id=story.id,
|
||||
)
|
||||
story.image_url = image_url
|
||||
sync_story_status(story, image_status=StoryAssetStatus.READY)
|
||||
await db.commit()
|
||||
return AssetCompletionResult(
|
||||
result = AssetCompletionResult(
|
||||
asset="cover_image",
|
||||
status=StoryAssetStatus.READY,
|
||||
value=image_url,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="cover_image_succeeded",
|
||||
status="succeeded",
|
||||
message="Cover image was generated.",
|
||||
metadata=_asset_result_metadata(result),
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
provider_error = str(exc)
|
||||
last_error = (
|
||||
@@ -362,18 +552,28 @@ async def _complete_cover_image_asset(
|
||||
await db.commit()
|
||||
logger.warning(log_event, story_id=story.id, error=provider_error)
|
||||
|
||||
result = AssetCompletionResult(
|
||||
asset="cover_image",
|
||||
status=StoryAssetStatus.FAILED,
|
||||
error=provider_error,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="cover_image_failed",
|
||||
status="failed",
|
||||
message="Cover image generation failed.",
|
||||
metadata=_asset_result_metadata(result),
|
||||
)
|
||||
if raise_on_failure:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Image generation failed: {provider_error}",
|
||||
) from exc
|
||||
|
||||
return AssetCompletionResult(
|
||||
asset="cover_image",
|
||||
status=StoryAssetStatus.FAILED,
|
||||
error=provider_error,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_storybook_pages_data(story: Story) -> list[dict]:
|
||||
@@ -385,6 +585,8 @@ def _get_storybook_pages_data(story: Story) -> list[dict]:
|
||||
async def _complete_storybook_image_assets(
|
||||
story: Story,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job=None,
|
||||
) -> AssetCompletionResult:
|
||||
"""Complete missing cover/page images for a persisted storybook."""
|
||||
|
||||
@@ -397,13 +599,38 @@ async def _complete_storybook_image_assets(
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_images_started",
|
||||
status="running",
|
||||
message="Storybook missing image completion started.",
|
||||
metadata={"asset": "image"},
|
||||
)
|
||||
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
completed_pages: list[int] = []
|
||||
|
||||
if story.cover_prompt and not story.image_url:
|
||||
try:
|
||||
story.image_url = await generate_image(story.cover_prompt, db=db)
|
||||
story.image_url = await generate_image(
|
||||
story.cover_prompt,
|
||||
db=db,
|
||||
user_id=story.user_id,
|
||||
generation_job=job,
|
||||
story_id=story.id,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_cover_image_succeeded",
|
||||
status="succeeded",
|
||||
message="Storybook cover image was generated.",
|
||||
metadata={"asset": "image", "scope": "cover"},
|
||||
)
|
||||
except Exception as exc:
|
||||
cover_failed = True
|
||||
logger.warning(
|
||||
@@ -411,13 +638,40 @@ async def _complete_storybook_image_assets(
|
||||
story_id=story.id,
|
||||
error=str(exc),
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_cover_image_failed",
|
||||
status="failed",
|
||||
message="Storybook cover image generation failed.",
|
||||
metadata={"asset": "image", "scope": "cover", "error": str(exc)},
|
||||
)
|
||||
|
||||
for page in pages_data:
|
||||
if not page.get("image_prompt") or page.get("image_url"):
|
||||
continue
|
||||
|
||||
try:
|
||||
page["image_url"] = await generate_image(page["image_prompt"], db=db)
|
||||
page["image_url"] = await generate_image(
|
||||
page["image_prompt"],
|
||||
db=db,
|
||||
user_id=story.user_id,
|
||||
generation_job=job,
|
||||
story_id=story.id,
|
||||
)
|
||||
page_number = page.get("page_number")
|
||||
if isinstance(page_number, int):
|
||||
completed_pages.append(page_number)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_page_image_succeeded",
|
||||
status="succeeded",
|
||||
message="Storybook page image was generated.",
|
||||
metadata={"asset": "image", "scope": "page", "page_number": page_number},
|
||||
)
|
||||
except Exception as exc:
|
||||
page_number = page.get("page_number")
|
||||
if isinstance(page_number, int):
|
||||
@@ -428,6 +682,20 @@ async def _complete_storybook_image_assets(
|
||||
page=page_number,
|
||||
error=str(exc),
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_page_image_failed",
|
||||
status="failed",
|
||||
message="Storybook page image generation failed.",
|
||||
metadata={
|
||||
"asset": "image",
|
||||
"scope": "page",
|
||||
"page_number": page_number,
|
||||
"error": str(exc),
|
||||
},
|
||||
)
|
||||
|
||||
story.pages = pages_data
|
||||
error_message = _build_storybook_error_message(
|
||||
@@ -446,12 +714,26 @@ async def _complete_storybook_image_assets(
|
||||
last_error=error_message,
|
||||
)
|
||||
await db.commit()
|
||||
return AssetCompletionResult(
|
||||
result = AssetCompletionResult(
|
||||
asset="storybook_images",
|
||||
status=image_status,
|
||||
value=story.image_url,
|
||||
error=error_message,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="storybook_images_completed",
|
||||
status="failed" if error_message else "succeeded",
|
||||
message="Storybook image completion finished.",
|
||||
metadata={
|
||||
**_asset_result_metadata(result),
|
||||
"completed_pages": sorted(completed_pages),
|
||||
"failed_pages": sorted(failed_pages),
|
||||
},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _read_cached_audio_asset(story: Story, db: AsyncSession) -> bytes | None:
|
||||
@@ -482,6 +764,7 @@ async def _complete_audio_asset(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
raise_on_failure: bool = True,
|
||||
job=None,
|
||||
) -> AssetCompletionResult:
|
||||
"""Complete TTS audio generation through one asset workflow."""
|
||||
|
||||
@@ -490,32 +773,67 @@ async def _complete_audio_asset(
|
||||
|
||||
cached_audio = await _read_cached_audio_asset(story, db)
|
||||
if cached_audio is not None:
|
||||
return AssetCompletionResult(
|
||||
result = AssetCompletionResult(
|
||||
asset="audio",
|
||||
status=StoryAssetStatus.READY,
|
||||
value=cached_audio,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="audio_cache_hit",
|
||||
status="succeeded",
|
||||
message="Cached story audio was reused.",
|
||||
metadata=_asset_result_metadata(result),
|
||||
)
|
||||
return result
|
||||
|
||||
from app.services.provider_router import text_to_speech
|
||||
|
||||
sync_story_status(story, audio_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="audio_started",
|
||||
status="running",
|
||||
message="Story audio generation started.",
|
||||
metadata={"asset": "audio"},
|
||||
)
|
||||
|
||||
try:
|
||||
audio_data = await text_to_speech(story.story_text, db=db)
|
||||
audio_data = await text_to_speech(
|
||||
story.story_text,
|
||||
db=db,
|
||||
user_id=story.user_id,
|
||||
generation_job=job,
|
||||
story_id=story.id,
|
||||
)
|
||||
story.audio_path = write_story_audio_cache(story.id, audio_data)
|
||||
sync_story_status(
|
||||
story,
|
||||
audio_status=StoryAssetStatus.READY,
|
||||
)
|
||||
await db.commit()
|
||||
return AssetCompletionResult(
|
||||
result = AssetCompletionResult(
|
||||
asset="audio",
|
||||
status=StoryAssetStatus.READY,
|
||||
value=audio_data,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="audio_succeeded",
|
||||
status="succeeded",
|
||||
message="Story audio was generated and cached.",
|
||||
metadata=_asset_result_metadata(result),
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
provider_error = str(exc)
|
||||
story.audio_path = None
|
||||
@@ -527,18 +845,28 @@ async def _complete_audio_asset(
|
||||
await db.commit()
|
||||
logger.error("audio_generation_failed", story_id=story.id, error=provider_error)
|
||||
|
||||
result = AssetCompletionResult(
|
||||
asset="audio",
|
||||
status=StoryAssetStatus.FAILED,
|
||||
error=provider_error,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
story_id=story.id,
|
||||
event_type="audio_failed",
|
||||
status="failed",
|
||||
message="Story audio generation failed.",
|
||||
metadata=_asset_result_metadata(result),
|
||||
)
|
||||
if raise_on_failure:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Audio generation failed: {provider_error}",
|
||||
) from exc
|
||||
|
||||
return AssetCompletionResult(
|
||||
asset="audio",
|
||||
status=StoryAssetStatus.FAILED,
|
||||
error=provider_error,
|
||||
blocks_main_result=raise_on_failure,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def validate_profile_and_universe(
|
||||
@@ -586,6 +914,8 @@ async def generate_and_save_story(
|
||||
request: GenerateRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job=None,
|
||||
) -> Story:
|
||||
"""Generate generic story content and save to DB."""
|
||||
profile_id, universe_id, memory_context = await _prepare_generation_context(
|
||||
@@ -593,21 +923,32 @@ async def generate_and_save_story(
|
||||
universe_id=request.universe_id,
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await generate_story_content(
|
||||
input_type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=job,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Story generation failed, please try again.",
|
||||
) from exc
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="narrative_generated",
|
||||
status="succeeded",
|
||||
message="Story narrative was generated.",
|
||||
metadata={"mode": result.mode, "title": result.title},
|
||||
)
|
||||
|
||||
return await _persist_text_story_result(
|
||||
result=result,
|
||||
@@ -615,6 +956,7 @@ async def generate_and_save_story(
|
||||
profile_id=profile_id,
|
||||
universe_id=universe_id,
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
|
||||
|
||||
@@ -622,9 +964,11 @@ async def generate_full_story_service(
|
||||
request: GenerateRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job=None,
|
||||
) -> FullStoryResponse:
|
||||
"""Generate story with parallel image generation."""
|
||||
story = await generate_and_save_story(request, user_id, db)
|
||||
story = await generate_and_save_story(request, user_id, db, job=job)
|
||||
image_url: str | None = None
|
||||
errors: dict[str, str | None] = {}
|
||||
|
||||
@@ -633,6 +977,7 @@ async def generate_full_story_service(
|
||||
story,
|
||||
db,
|
||||
log_event="image_generation_failed",
|
||||
job=job,
|
||||
)
|
||||
if image_result.succeeded and isinstance(image_result.value, str):
|
||||
image_url = image_result.value
|
||||
@@ -651,6 +996,7 @@ async def generate_full_story_service(
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
generation_status=story.generation_status,
|
||||
text_status=story.text_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
@@ -662,6 +1008,8 @@ async def generate_storybook_service(
|
||||
request: StorybookRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job=None,
|
||||
) -> StorybookResponse:
|
||||
"""Generate storybook with parallel image generation for pages."""
|
||||
profile_id, universe_id, memory_context = await _prepare_generation_context(
|
||||
@@ -669,6 +1017,7 @@ async def generate_storybook_service(
|
||||
universe_id=request.universe_id,
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -684,13 +1033,27 @@ async def generate_storybook_service(
|
||||
storybook = await generate_storybook(
|
||||
keywords=request.keywords,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as e:
|
||||
education_theme=request.education_theme,
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
generation_job=job,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("storybook_generation_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"故事书生成失败: {e}")
|
||||
await _record_job_event_if_present(
|
||||
db,
|
||||
job=job,
|
||||
event_type="narrative_generated",
|
||||
status="succeeded",
|
||||
message="Storybook narrative and page plan were generated.",
|
||||
metadata={
|
||||
"mode": "storybook",
|
||||
"title": storybook.title,
|
||||
"page_count": len(storybook.pages),
|
||||
},
|
||||
)
|
||||
|
||||
final_cover_url = storybook.cover_url
|
||||
cover_failed = False
|
||||
@@ -701,7 +1064,12 @@ async def generate_storybook_service(
|
||||
final_cover_url,
|
||||
cover_failed,
|
||||
failed_pages,
|
||||
) = await _generate_storybook_image_assets(storybook, db)
|
||||
) = await _generate_storybook_image_assets(
|
||||
storybook,
|
||||
db,
|
||||
user_id=user_id,
|
||||
job=job,
|
||||
)
|
||||
|
||||
story, pages_data = await _persist_storybook_result(
|
||||
storybook=storybook,
|
||||
@@ -713,6 +1081,7 @@ async def generate_storybook_service(
|
||||
cover_failed=cover_failed,
|
||||
failed_pages=failed_pages,
|
||||
db=db,
|
||||
job=job,
|
||||
)
|
||||
|
||||
response_pages = _storybook_pages_to_response(pages_data)
|
||||
@@ -726,6 +1095,7 @@ async def generate_storybook_service(
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
cover_url=final_cover_url,
|
||||
generation_status=story.generation_status,
|
||||
text_status=story.text_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
@@ -797,6 +1167,7 @@ async def _generate_generation_service_with_job(
|
||||
),
|
||||
user_id,
|
||||
db,
|
||||
job=job,
|
||||
)
|
||||
if storybook.id is None:
|
||||
raise HTTPException(status_code=500, detail="Storybook generation did not persist.")
|
||||
@@ -812,6 +1183,7 @@ async def _generate_generation_service_with_job(
|
||||
)
|
||||
return GenerationResponse(
|
||||
id=storybook.id,
|
||||
generation_job_id=job.id,
|
||||
title=storybook.title,
|
||||
mode="storybook",
|
||||
pages=storybook.pages,
|
||||
@@ -821,6 +1193,7 @@ async def _generate_generation_service_with_job(
|
||||
main_character=storybook.main_character,
|
||||
art_style=storybook.art_style,
|
||||
generation_status=storybook.generation_status,
|
||||
text_status=saved_story.text_status,
|
||||
image_status=storybook.image_status,
|
||||
audio_status=storybook.audio_status,
|
||||
last_error=storybook.last_error,
|
||||
@@ -838,7 +1211,7 @@ async def _generate_generation_service_with_job(
|
||||
)
|
||||
|
||||
if request.generate_images:
|
||||
story = await generate_full_story_service(generate_request, user_id, db)
|
||||
story = await generate_full_story_service(generate_request, user_id, db, job=job)
|
||||
saved_story = await get_story_detail(story.id, user_id, db)
|
||||
await _record_postprocessing_event_if_needed(db, job=job, story=saved_story)
|
||||
await finish_generation_job(
|
||||
@@ -850,6 +1223,7 @@ async def _generate_generation_service_with_job(
|
||||
)
|
||||
return GenerationResponse(
|
||||
id=story.id,
|
||||
generation_job_id=job.id,
|
||||
title=story.title,
|
||||
mode=story.mode,
|
||||
story_text=story.story_text,
|
||||
@@ -859,6 +1233,7 @@ async def _generate_generation_service_with_job(
|
||||
audio_ready=story.audio_ready,
|
||||
errors=story.errors,
|
||||
generation_status=story.generation_status,
|
||||
text_status=saved_story.text_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
@@ -867,7 +1242,7 @@ async def _generate_generation_service_with_job(
|
||||
retryable_assets=saved_story.retryable_assets,
|
||||
)
|
||||
|
||||
story = await generate_and_save_story(generate_request, user_id, db)
|
||||
story = await generate_and_save_story(generate_request, user_id, db, job=job)
|
||||
await _record_postprocessing_event_if_needed(db, job=job, story=story)
|
||||
await finish_generation_job(
|
||||
db,
|
||||
@@ -878,6 +1253,7 @@ async def _generate_generation_service_with_job(
|
||||
)
|
||||
return GenerationResponse(
|
||||
id=story.id,
|
||||
generation_job_id=job.id,
|
||||
title=story.title,
|
||||
mode=story.mode,
|
||||
story_text=story.story_text,
|
||||
@@ -885,6 +1261,7 @@ async def _generate_generation_service_with_job(
|
||||
image_url=story.image_url,
|
||||
cover_url=story.image_url,
|
||||
generation_status=story.generation_status,
|
||||
text_status=story.text_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
@@ -954,7 +1331,7 @@ async def create_story_from_result(
|
||||
)
|
||||
|
||||
|
||||
async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
|
||||
async def _retry_cover_image_asset(story: Story, db: AsyncSession, *, job=None) -> None:
|
||||
"""Retry cover generation for a text story."""
|
||||
|
||||
await _complete_cover_image_asset(
|
||||
@@ -962,19 +1339,25 @@ async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
|
||||
db,
|
||||
last_error_prefix="封面生成失败",
|
||||
log_event="cover_asset_retry_failed",
|
||||
job=job,
|
||||
)
|
||||
|
||||
|
||||
async def _retry_storybook_image_assets(story: Story, db: AsyncSession) -> None:
|
||||
async def _retry_storybook_image_assets(
|
||||
story: Story,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
job=None,
|
||||
) -> None:
|
||||
"""Retry missing storybook cover/page images."""
|
||||
|
||||
await _complete_storybook_image_assets(story, db)
|
||||
await _complete_storybook_image_assets(story, db, job=job)
|
||||
|
||||
|
||||
async def _retry_audio_asset(story: Story, db: AsyncSession) -> None:
|
||||
async def _retry_audio_asset(story: Story, db: AsyncSession, *, job=None) -> None:
|
||||
"""Retry audio generation while preserving persisted status on provider failure."""
|
||||
|
||||
await _complete_audio_asset(story, db, raise_on_failure=False)
|
||||
await _complete_audio_asset(story, db, raise_on_failure=False, job=job)
|
||||
|
||||
|
||||
async def retry_story_assets(
|
||||
@@ -1009,12 +1392,12 @@ async def retry_story_assets(
|
||||
|
||||
if "image" in requested_assets:
|
||||
if story.mode == "storybook":
|
||||
await _retry_storybook_image_assets(story, db)
|
||||
await _retry_storybook_image_assets(story, db, job=job)
|
||||
else:
|
||||
await _retry_cover_image_asset(story, db)
|
||||
await _retry_cover_image_asset(story, db, job=job)
|
||||
|
||||
if "audio" in requested_assets:
|
||||
await _retry_audio_asset(story, db)
|
||||
await _retry_audio_asset(story, db, job=job)
|
||||
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await finish_generation_job(
|
||||
@@ -1075,6 +1458,7 @@ async def generate_story_cover(
|
||||
db,
|
||||
raise_on_failure=True,
|
||||
log_event="cover_generation_failed",
|
||||
job=job,
|
||||
)
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await finish_generation_job(
|
||||
@@ -1121,7 +1505,12 @@ async def generate_story_audio(
|
||||
|
||||
try:
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
audio_result = await _complete_audio_asset(story, db, raise_on_failure=True)
|
||||
audio_result = await _complete_audio_asset(
|
||||
story,
|
||||
db,
|
||||
raise_on_failure=True,
|
||||
job=job,
|
||||
)
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await finish_generation_job(
|
||||
db,
|
||||
|
||||
@@ -10,6 +10,7 @@ class StoryGenerationStatus(str, Enum):
|
||||
"""Overall story generation lifecycle."""
|
||||
|
||||
NARRATIVE_READY = "narrative_ready"
|
||||
PARTIAL_READY = "partial_ready"
|
||||
ASSETS_GENERATING = "assets_generating"
|
||||
COMPLETED = "completed"
|
||||
DEGRADED_COMPLETED = "degraded_completed"
|
||||
@@ -30,7 +31,10 @@ class StoryLike(Protocol):
|
||||
|
||||
story_text: str | None
|
||||
pages: list[dict] | None
|
||||
cover_prompt: str | None
|
||||
image_url: str | None
|
||||
generation_status: str
|
||||
text_status: str
|
||||
image_status: str
|
||||
audio_status: str
|
||||
last_error: str | None
|
||||
@@ -55,6 +59,37 @@ def has_narrative_content(story: StoryLike) -> bool:
|
||||
return bool(story.story_text) or bool(story.pages)
|
||||
|
||||
|
||||
def _has_retryable_image(story: StoryLike, image_status: StoryAssetStatus) -> bool:
|
||||
if image_status in {StoryAssetStatus.READY, StoryAssetStatus.GENERATING}:
|
||||
return False
|
||||
|
||||
pages = story.pages or []
|
||||
has_missing_page_image = any(
|
||||
isinstance(page, dict)
|
||||
and page.get("image_prompt")
|
||||
and not page.get("image_url")
|
||||
for page in pages
|
||||
)
|
||||
return bool(story.cover_prompt and not story.image_url) or has_missing_page_image
|
||||
|
||||
|
||||
def _has_pending_assets(
|
||||
story: StoryLike,
|
||||
*,
|
||||
image_status: StoryAssetStatus,
|
||||
audio_status: StoryAssetStatus,
|
||||
) -> bool:
|
||||
"""Whether readable content still has optional assets to complete."""
|
||||
|
||||
if _has_retryable_image(story, image_status):
|
||||
return True
|
||||
|
||||
return bool(story.story_text) and audio_status not in {
|
||||
StoryAssetStatus.READY,
|
||||
StoryAssetStatus.GENERATING,
|
||||
}
|
||||
|
||||
|
||||
def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
|
||||
"""Derive the overall status from narrative and asset states."""
|
||||
|
||||
@@ -70,6 +105,9 @@ def resolve_story_generation_status(story: StoryLike) -> StoryGenerationStatus:
|
||||
if StoryAssetStatus.FAILED in (image_status, audio_status):
|
||||
return StoryGenerationStatus.DEGRADED_COMPLETED
|
||||
|
||||
if _has_pending_assets(story, image_status=image_status, audio_status=audio_status):
|
||||
return StoryGenerationStatus.PARTIAL_READY
|
||||
|
||||
if (
|
||||
image_status == StoryAssetStatus.NOT_REQUESTED
|
||||
and audio_status == StoryAssetStatus.NOT_REQUESTED
|
||||
@@ -105,6 +143,12 @@ def sync_story_status(
|
||||
if last_error is not _ERROR_UNSET:
|
||||
story.last_error = last_error
|
||||
|
||||
story.text_status = (
|
||||
StoryAssetStatus.READY.value
|
||||
if has_narrative_content(story)
|
||||
else StoryAssetStatus.FAILED.value
|
||||
)
|
||||
|
||||
generation_status = resolve_story_generation_status(story)
|
||||
story.generation_status = generation_status.value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user