374 lines
11 KiB
Python
374 lines
11 KiB
Python
"""Lightweight generation job/event tracking."""
|
|
|
|
from __future__ import annotations

import math
from typing import Any

from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import GenerationJob, GenerationJobEvent, Story
|
|
|
|
|
|
def _story_snapshot(story: Story | None) -> dict[str, Any]:
|
|
if story is None:
|
|
return {}
|
|
|
|
return {
|
|
"story_id": story.id,
|
|
"mode": story.mode,
|
|
"generation_status": story.generation_status,
|
|
"text_status": story.text_status,
|
|
"image_status": story.image_status,
|
|
"audio_status": story.audio_status,
|
|
"retryable_assets": story.retryable_assets,
|
|
"last_error": story.last_error,
|
|
}
|
|
|
|
|
|
def _job_status_from_story(story: Story) -> str:
|
|
if story.generation_status == "failed":
|
|
return "failed"
|
|
if story.generation_status == "degraded_completed":
|
|
return "degraded_completed"
|
|
return "completed"
|
|
|
|
|
|
def _job_progress(job: GenerationJob) -> dict[str, Any]:
|
|
"""Resolve a compact progress summary for polling-oriented clients."""
|
|
|
|
if job.status == "failed":
|
|
return {
|
|
"progress_percent": 100,
|
|
"progress_label": "生成失败",
|
|
"is_terminal": True,
|
|
}
|
|
|
|
if job.status in {"completed", "degraded_completed"}:
|
|
return {
|
|
"progress_percent": 100,
|
|
"progress_label": "已完成" if job.status == "completed" else "降级完成",
|
|
"is_terminal": True,
|
|
}
|
|
|
|
progress_map: dict[str, tuple[int, str]] = {
|
|
"request_accepted": (5, "已接收请求"),
|
|
"context_prepared": (20, "上下文已准备"),
|
|
"narrative_generated": (45, "正文已生成"),
|
|
"story_saved": (60, "主记录已保存"),
|
|
"provider_call_started": (65, "Provider 调用中"),
|
|
"provider_call_succeeded": (72, "Provider 调用成功"),
|
|
"provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
|
|
"cover_image_started": (75, "封面生成中"),
|
|
"storybook_images_started": (75, "绘本插图生成中"),
|
|
"audio_started": (75, "音频生成中"),
|
|
"asset_retry_started": (25, "资源重试中"),
|
|
"postprocessing_queued": (90, "后处理已排队"),
|
|
"asset_generation_completed": (100, "资源已完成"),
|
|
"asset_retry_completed": (100, "资源重试完成"),
|
|
"generation_completed": (100, "生成完成"),
|
|
}
|
|
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
|
|
return {
|
|
"progress_percent": percent,
|
|
"progress_label": label,
|
|
"is_terminal": percent >= 100,
|
|
}
|
|
|
|
|
|
async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Persist a new running generation job together with its opening event.

    The job starts with ``status="running"`` at step ``"request_accepted"`` and
    an empty ``result_snapshot``. A ``request_accepted`` event is recorded
    without committing, then the job and the event are committed together and
    the job is refreshed before being returned.
    """
    new_job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(new_job)
    # Flush (no commit) so the job row exists before the event reads new_job.id.
    await db.flush()

    opening_metadata = {"output_mode": output_mode, "input_type": input_type}
    await record_generation_event(
        db,
        job=new_job,
        story_id=story_id,
        event_type="request_accepted",
        status="succeeded",
        message="Generation request accepted.",
        metadata=opening_metadata,
        commit=False,  # committed below in the same transaction as the job
    )

    await db.commit()
    await db.refresh(new_job)
    return new_job
|
|
|
|
|
|
async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Append one event to an existing generation job.

    Args:
        db: Session the event is added to.
        job: Owning job; its ``id`` becomes the event's ``job_id``.
        event_type: Event name stored on the row.
        status: Event status stored on the row.
        story_id: Explicit story id; when None, ``job.story_id`` is used.
        message: Optional human-readable message.
        metadata: Stored as ``event_metadata``; a falsy value becomes ``{}``.
        commit: When True the session is committed; when False the caller is
            responsible for committing (the event is only added to the session).

    Returns:
        The new ``GenerationJobEvent``.

    NOTE(review): unlike the job helpers, the event is not refreshed after
    commit; whether its server-populated columns are readable afterwards
    depends on the session's ``expire_on_commit`` setting — confirm.
    """
    resolved_story_id = job.story_id if story_id is None else story_id
    event = GenerationJobEvent(
        job_id=job.id,
        story_id=resolved_story_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata if metadata else {},
    )
    db.add(event)

    if not commit:
        return event

    await db.commit()
    return event
|
|
|
|
|
|
async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Mark a generation job as completed/degraded/failed and append a final event.

    Status resolution: an explicit truthy ``status`` wins. Otherwise it is
    derived from ``story`` via ``_job_status_from_story``, or ``"failed"``
    when there is no story. ``error_message`` is always overwritten (None
    clears it). The job's ``result_snapshot`` is replaced with a snapshot of
    ``story`` (``{}`` when absent).

    The closing event uses ``current_step`` as its event type and the job's
    resolved status. Its metadata is the caller's ``metadata`` plus a
    ``result_snapshot`` entry, and that entry takes precedence. Everything is
    committed in one transaction and the refreshed job is returned.
    """
    if story is not None:
        job.story_id = story.id
        resolved_status = status or _job_status_from_story(story)
    else:
        # Keep whatever story the job was already linked to.
        resolved_status = status or "failed"

    job.status = resolved_status
    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)

    final_metadata = dict(metadata or {})
    final_metadata["result_snapshot"] = job.result_snapshot

    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata=final_metadata,
        commit=False,  # committed together with the job update below
    )

    await db.commit()
    await db.refresh(job)
    return job
|
|
|
|
|
|
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Convert a generation event ORM object to an API response dict.

    Scalar columns are copied as-is. ``event_metadata`` falls back to ``{}``
    when it is falsy.
    """
    copied_columns = ("id", "job_id", "story_id", "event_type", "status", "message")
    response: dict[str, Any] = {column: getattr(event, column) for column in copied_columns}
    response["event_metadata"] = event.event_metadata or {}
    response["created_at"] = event.created_at
    return response
|
|
|
|
|
|
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Convert a generation job ORM object to an API summary dict.

    The progress fields from ``_job_progress`` (``progress_percent``,
    ``progress_label``, ``is_terminal``) are placed between ``current_step``
    and ``result_snapshot``. ``result_snapshot`` falls back to ``{}`` when
    falsy.
    """
    summary: dict[str, Any] = {
        "id": job.id,
        "story_id": job.story_id,
        "output_mode": job.output_mode,
        "input_type": job.input_type,
        "status": job.status,
        "current_step": job.current_step,
    }
    summary.update(_job_progress(job))
    summary["result_snapshot"] = job.result_snapshot or {}
    summary["error_message"] = job.error_message
    summary["created_at"] = job.created_at
    summary["updated_at"] = job.updated_at
    return summary
|
|
|
|
|
|
async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return a user-owned generation job with its ordered event stream.

    The job is looked up by id and owner together, so another user's job is
    indistinguishable from a missing one. Events are ordered by ascending
    event id.

    Raises:
        HTTPException: 404 when no job with ``job_id`` belongs to ``user_id``.
    """
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    job = (await db.execute(owned_job_query)).scalar_one_or_none()
    if job is None:
        raise HTTPException(status_code=404, detail="Generation job not found")

    event_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    event_rows = (await db.execute(event_query)).scalars().all()

    detail = generation_job_to_summary(job)
    detail["request_payload"] = job.request_payload or {}
    detail["events"] = [generation_event_to_response(row) for row in event_rows]
    return detail
|
|
|
|
|
|
async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """Return summaries of a user's generation jobs for one story, newest first.

    Jobs are ordered by ``created_at`` descending, with ``id`` descending as a
    tie-breaker. No limit is applied.
    """
    newest_first_query = (
        select(GenerationJob)
        .where(GenerationJob.story_id == story_id, GenerationJob.user_id == user_id)
        .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
    )
    result = await db.execute(newest_first_query)
    return list(map(generation_job_to_summary, result.scalars().all()))
|
|
|
|
|
|
def _as_float(value: Any) -> float | None:
|
|
if isinstance(value, int | float):
|
|
return float(value)
|
|
return None
|
|
|
|
|
|
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> dict[str, Any]:
    """Aggregate provider-call telemetry across all of a user's jobs for one story.

    Only ``provider_call_succeeded`` and ``provider_call_failed`` events are
    read. They are grouped by the ``capability`` and ``adapter`` metadata
    values, each falling back to ``"unknown"``.

    * Latency is averaged over events whose ``latency_ms`` passes ``_as_float``.
    * Estimated cost is summed over successful calls only.
    * Averages are rounded to 2 decimals and costs to 6.
    * Per-provider rows are sorted by (capability, adapter).
    """
    provider_event_query = (
        select(GenerationJobEvent)
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJobEvent.event_type.in_(
                ["provider_call_succeeded", "provider_call_failed"]
            ),
        )
        .order_by(GenerationJobEvent.id)
    )
    provider_events = (await db.execute(provider_event_query)).scalars().all()

    groups: dict[tuple[str, str], dict[str, Any]] = {}
    ok_total = 0
    failed_total = 0
    latency_sum = 0.0
    latency_samples = 0
    cost_sum = 0.0

    for evt in provider_events:
        meta = evt.event_metadata or {}
        cap = str(meta.get("capability") or "unknown")
        adp = str(meta.get("adapter") or "unknown")

        group = groups.get((cap, adp))
        if group is None:
            group = {
                "capability": cap,
                "adapter": adp,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            }
            groups[(cap, adp)] = group
        group["call_count"] += 1

        latency_ms = _as_float(meta.get("latency_ms"))
        if latency_ms is not None:
            group["latency_total"] += latency_ms
            group["latency_count"] += 1
            latency_sum += latency_ms
            latency_samples += 1

        if evt.event_type != "provider_call_succeeded":
            group["failure_count"] += 1
            failed_total += 1
            continue

        group["success_count"] += 1
        ok_total += 1
        # Cost is only accumulated for successful calls.
        call_cost = _as_float(meta.get("estimated_cost_usd")) or 0.0
        group["estimated_cost_usd"] += call_cost
        cost_sum += call_cost

    per_provider: list[dict[str, Any]] = []
    for group in groups.values():
        samples = group["latency_count"]
        per_provider.append(
            {
                "capability": group["capability"],
                "adapter": group["adapter"],
                "call_count": group["call_count"],
                "success_count": group["success_count"],
                "failure_count": group["failure_count"],
                "estimated_cost_usd": round(group["estimated_cost_usd"], 6),
                "avg_latency_ms": (
                    round(group["latency_total"] / samples, 2) if samples else None
                ),
            }
        )
    per_provider = sorted(
        per_provider,
        key=lambda row: (str(row["capability"]), str(row["adapter"])),
    )

    overall_avg_latency = (
        round(latency_sum / latency_samples, 2) if latency_samples else None
    )
    return {
        "story_id": story_id,
        "total_calls": ok_total + failed_total,
        "successful_calls": ok_total,
        "failed_calls": failed_total,
        "avg_latency_ms": overall_avg_latency,
        "estimated_cost_usd": round(cost_sum, 6),
        "by_provider": per_provider,
    }
|