feat: improve generation analytics and maintenance

This commit is contained in:
2026-04-19 09:03:40 +08:00
parent d5a173aa0d
commit 5318de670f
21 changed files with 1155 additions and 57 deletions

View File

@@ -2,14 +2,19 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, distinct, func, select
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import GenerationJob, GenerationJobEvent, Story
logger = get_logger(__name__)
def _story_snapshot(story: Story | None) -> dict[str, Any]:
if story is None:
@@ -68,6 +73,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_completed": (100, "生成完成"),
"generation_stale_failed": (100, "任务超时已收敛"),
}
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
return {
@@ -77,6 +83,27 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
}
def _normalize_datetime(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
return job.status == "running" and _normalize_datetime(job.updated_at) <= cutoff
def _failure_label(job: GenerationJob) -> str:
if job.current_step == "generation_stale_failed":
return "任务超时"
if job.output_mode == "asset_retry":
return "资源重试失败"
if job.output_mode == "asset_generation":
return "资源生成失败"
return "生成失败"
async def create_generation_job(
db: AsyncSession,
*,
@@ -266,16 +293,64 @@ async def list_story_generation_jobs(
return [generation_job_to_summary(job) for job in jobs]
async def get_active_story_generation_job(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> GenerationJob | None:
"""Return the most recent running job for a story, if any."""
result = await db.execute(
select(GenerationJob)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJob.status == "running",
)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
.limit(1)
)
return result.scalar_one_or_none()
async def ensure_no_active_story_generation_job(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> None:
"""Prevent duplicate asset work while a story already has a running job."""
active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
if active_job is None:
return
progress = _job_progress(active_job)
raise HTTPException(
status_code=409,
detail=(
f"当前故事已有运行中的任务({progress['progress_label']}"
"请等待当前任务完成后再试。"
),
)
def _as_float(value: Any) -> float | None:
if isinstance(value, int | float):
return float(value)
return None
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
def _aggregate_provider_events(
events: list[GenerationJobEvent],
*,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry from provider call events."""
by_key: dict[tuple[str, str], dict[str, Any]] = {}
failure_reasons: dict[str, int] = {}
total_latency = 0.0
latency_count = 0
total_cost = 0.0
@@ -284,13 +359,16 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
for event in events:
metadata = event.event_metadata or {}
capability = str(metadata.get("capability") or "unknown")
event_capability = str(metadata.get("capability") or "unknown")
if capability is not None and event_capability != capability:
continue
adapter = str(metadata.get("adapter") or "unknown")
key = (capability, adapter)
key = (event_capability, adapter)
bucket = by_key.setdefault(
key,
{
"capability": capability,
"capability": event_capability,
"adapter": adapter,
"call_count": 0,
"success_count": 0,
@@ -318,6 +396,8 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
else:
bucket["failure_count"] += 1
failed_calls += 1
reason = str(metadata.get("error") or "unknown_error")
failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
by_provider = []
for bucket in by_key.values():
@@ -349,67 +429,243 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
"estimated_cost_usd": round(total_cost, 6),
"by_provider": by_provider,
"failure_reasons": [
{"reason": reason, "count": count}
for reason, count in sorted(
failure_reasons.items(),
key=lambda item: (-item[1], item[0]),
)
],
}
def _provider_events_query(
*,
user_id: str,
story_id: int | None = None,
days: int | None = None,
):
query = (
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
)
if story_id is not None:
query = query.where(GenerationJob.story_id == story_id)
if days is not None:
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
query = query.where(GenerationJobEvent.created_at >= cutoff)
return query.order_by(GenerationJobEvent.id)
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
_provider_events_query(
user_id=user_id,
story_id=story_id,
days=days,
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
return {"story_id": story_id, **_aggregate_provider_events(events)}
return {
"story_id": story_id,
"window_days": days,
"capability": capability,
**_aggregate_provider_events(events, capability=capability),
}
async def get_user_provider_analytics(
db: AsyncSession,
*,
user_id: str,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry across all stories owned by one user."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
_provider_events_query(
user_id=user_id,
days=days,
)
)
).scalars().all()
filtered_event_job_ids = {
event.job_id
for event in events
if capability is None
or str((event.event_metadata or {}).get("capability") or "unknown") == capability
}
filtered_story_ids = {
event.story_id
for event in events
if event.story_id is not None
and (
capability is None
or str((event.event_metadata or {}).get("capability") or "unknown") == capability
)
}
return {
"window_days": days,
"capability": capability,
**_aggregate_provider_events(events, capability=capability),
"job_count": len(filtered_event_job_ids),
"story_count": len(filtered_story_ids),
}
async def get_user_generation_ops_summary(
db: AsyncSession,
*,
user_id: str,
hours: int = 24,
recent_failure_limit: int = 5,
) -> dict[str, Any]:
"""Summarize recent generation health for one user."""
stale_after_minutes = settings.generation_job_stale_minutes
recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
running_jobs = (
await db.execute(
select(GenerationJob)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
GenerationJob.status == "running",
)
.order_by(GenerationJobEvent.id)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
)
).scalars().all()
job_count, story_count = (
recent_jobs = (
await db.execute(
select(
func.count(GenerationJob.id),
func.count(distinct(GenerationJob.story_id)),
).where(GenerationJob.user_id == user_id)
select(GenerationJob, Story.title)
.outerjoin(Story, Story.id == GenerationJob.story_id)
.where(
GenerationJob.user_id == user_id,
GenerationJob.updated_at >= recent_cutoff,
)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
)
).one()
).all()
recent_failures: list[dict[str, Any]] = []
failed_jobs = 0
degraded_jobs = 0
asset_retry_jobs = 0
for job, story_title in recent_jobs:
if job.status == "failed":
failed_jobs += 1
if len(recent_failures) < recent_failure_limit:
recent_failures.append(
{
"job_id": job.id,
"story_id": job.story_id,
"story_title": story_title,
"output_mode": job.output_mode,
"current_step": job.current_step,
"error_message": job.error_message,
"failure_label": _failure_label(job),
"updated_at": job.updated_at,
}
)
elif job.status == "degraded_completed":
degraded_jobs += 1
if job.output_mode in {"asset_retry", "asset_generation"}:
asset_retry_jobs += 1
return {
**_aggregate_provider_events(events),
"job_count": job_count,
"story_count": story_count,
"window_hours": hours,
"stale_threshold_minutes": stale_after_minutes,
"active_jobs": len(running_jobs),
"stale_running_jobs": sum(
1 for job in running_jobs if _is_stale_job(job, stale_after_minutes=stale_after_minutes)
),
"failed_jobs": failed_jobs,
"degraded_jobs": degraded_jobs,
"asset_retry_jobs": asset_retry_jobs,
"recent_failures": recent_failures,
}
async def mark_stale_generation_jobs(
db: AsyncSession,
*,
stale_after_minutes: int | None = None,
) -> dict[str, int]:
"""Mark long-running generation jobs as failed so they no longer appear stuck forever."""
threshold = stale_after_minutes or settings.generation_job_stale_minutes
running_jobs = (
await db.execute(
select(GenerationJob)
.where(GenerationJob.status == "running")
.order_by(GenerationJob.updated_at, GenerationJob.id)
)
).scalars().all()
marked_stale = 0
for job in running_jobs:
if not _is_stale_job(job, stale_after_minutes=threshold):
continue
story = None
if job.story_id is not None:
story = (
await db.execute(
select(Story).where(
Story.id == job.story_id,
Story.user_id == job.user_id,
)
)
).scalar_one_or_none()
await finish_generation_job(
db,
job=job,
story=story,
status="failed",
current_step="generation_stale_failed",
error_message=f"Generation job exceeded {threshold} minutes without progress.",
message="Generation job was marked failed after exceeding the stale threshold.",
metadata={"stale_after_minutes": threshold},
)
marked_stale += 1
logger.warning(
"generation_job_marked_stale",
job_id=job.id,
story_id=job.story_id,
output_mode=job.output_mode,
stale_after_minutes=threshold,
)
return {
"running": len(running_jobs),
"marked_stale": marked_stale,
"stale_after_minutes": threshold,
}