feat: improve generation analytics and maintenance

This commit is contained in:
2026-04-19 09:03:40 +08:00
parent d5a173aa0d
commit 5318de670f
21 changed files with 1155 additions and 57 deletions

View File

@@ -4,7 +4,7 @@ import json
import uuid
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, Response
from fastapi import APIRouter, Depends, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse
@@ -19,6 +19,7 @@ from app.schemas.story_schemas import (
GenerateRequest,
GenerationJobDetailResponse,
GenerationJobSummaryResponse,
GenerationOpsSummaryResponse,
GenerationProviderAnalyticsResponse,
GenerationProviderStatsResponse,
GenerationRequest,
@@ -36,6 +37,7 @@ from app.services import story_service
from app.services.generation_jobs import (
get_generation_job_detail,
get_story_provider_stats,
get_user_generation_ops_summary,
get_user_provider_analytics,
list_story_generation_jobs,
)
@@ -86,16 +88,36 @@ async def get_generation_job(
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
@router.get(
"/generations/ops-summary",
response_model=GenerationOpsSummaryResponse,
)
async def get_generation_ops_summary(
hours: int = Query(default=24, ge=1, le=168),
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get a compact recent operations summary for generation workflows."""
return await get_user_generation_ops_summary(db, user_id=user.id, hours=hours)
@router.get(
"/generations/provider-analytics",
response_model=GenerationProviderAnalyticsResponse,
)
async def get_generation_provider_analytics(
days: int | None = Query(default=None, ge=1, le=365),
capability: str | None = Query(default=None),
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get provider call stats aggregated across the user's generation history."""
return await get_user_provider_analytics(db, user_id=user.id)
return await get_user_provider_analytics(
db,
user_id=user.id,
days=days,
capability=capability,
)
@router.get(
@@ -117,11 +139,19 @@ async def list_generation_jobs(
)
async def get_generation_provider_stats(
story_id: int,
days: int | None = Query(default=None, ge=1, le=365),
capability: str | None = Query(default=None),
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get provider call stats aggregated from generation job events."""
return await get_story_provider_stats(db, story_id=story_id, user_id=user.id)
return await get_story_provider_stats(
db,
story_id=story_id,
user_id=user.id,
days=days,
capability=capability,
)
@router.get("/generations/{story_id}", response_model=StoryDetailResponse)

View File

@@ -49,6 +49,14 @@ celery_app.conf.update(
"task": "app.tasks.memory.prune_memories_task",
"schedule": crontab(minute="0", hour="3"), # Daily at 03:00
},
"prune_story_audio_cache": {
"task": "app.tasks.audio_cache.prune_story_audio_cache_task",
"schedule": crontab(minute="30", hour="3"), # Daily at 03:30
},
"prune_stale_generation_jobs": {
"task": "app.tasks.generation_maintenance.prune_stale_generation_jobs_task",
"schedule": crontab(minute="*/30"),
},
},
)

View File

@@ -62,12 +62,20 @@ class Settings(BaseSettings):
False,
description="Enable local deterministic demo providers for portfolio demos",
)
story_audio_cache_dir: str = Field(
"storage/audio",
description="Directory for cached story audio files",
)
# Celery (Redis)
story_audio_cache_dir: str = Field(
"storage/audio",
description="Directory for cached story audio files",
)
story_audio_cache_ttl_days: int = Field(
30,
description="TTL in days before cached story audio is pruned",
)
generation_job_stale_minutes: int = Field(
60,
description="Minutes before a running generation job is considered stale",
)
# Celery (Redis)
celery_broker_url: str = Field("redis://localhost:6379/0")
celery_result_backend: str = Field("redis://localhost:6379/0")

View File

@@ -220,21 +220,33 @@ class GenerationProviderStatResponse(BaseModel):
estimated_cost_usd: float = 0.0
class GenerationProviderFailureReasonResponse(BaseModel):
"""Aggregated failed provider call reason."""
reason: str
count: int
class GenerationProviderStatsResponse(BaseModel):
"""Provider call stats aggregated from generation job events."""
story_id: int
window_days: int | None = None
capability: str | None = None
total_calls: int
successful_calls: int
failed_calls: int
avg_latency_ms: float | None = None
estimated_cost_usd: float = 0.0
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
failure_reasons: list[GenerationProviderFailureReasonResponse] = Field(default_factory=list)
class GenerationProviderAnalyticsResponse(BaseModel):
"""Provider call stats aggregated across one user's generation history."""
window_days: int | None = None
capability: str | None = None
total_calls: int
successful_calls: int
failed_calls: int
@@ -243,6 +255,33 @@ class GenerationProviderAnalyticsResponse(BaseModel):
job_count: int
story_count: int
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
failure_reasons: list[GenerationProviderFailureReasonResponse] = Field(default_factory=list)
class GenerationRecentFailureResponse(BaseModel):
"""One recent failed generation task for operations summary."""
job_id: str
story_id: int | None = None
story_title: str | None = None
output_mode: str
current_step: str
error_message: str | None = None
failure_label: str
updated_at: datetime
class GenerationOpsSummaryResponse(BaseModel):
"""Recent generation health summary for one user."""
window_hours: int
stale_threshold_minutes: int
active_jobs: int
stale_running_jobs: int
failed_jobs: int
degraded_jobs: int
asset_retry_jobs: int
recent_failures: list[GenerationRecentFailureResponse] = Field(default_factory=list)
class AchievementItem(BaseModel):

View File

@@ -2,14 +2,19 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, distinct, func, select
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import GenerationJob, GenerationJobEvent, Story
logger = get_logger(__name__)
def _story_snapshot(story: Story | None) -> dict[str, Any]:
if story is None:
@@ -68,6 +73,7 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
"asset_generation_completed": (100, "资源已完成"),
"asset_retry_completed": (100, "资源重试完成"),
"generation_completed": (100, "生成完成"),
"generation_stale_failed": (100, "任务超时已收敛"),
}
percent, label = progress_map.get(job.current_step, (10, "生成处理中"))
return {
@@ -77,6 +83,27 @@ def _job_progress(job: GenerationJob) -> dict[str, Any]:
}
def _normalize_datetime(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
return job.status == "running" and _normalize_datetime(job.updated_at) <= cutoff
def _failure_label(job: GenerationJob) -> str:
if job.current_step == "generation_stale_failed":
return "任务超时"
if job.output_mode == "asset_retry":
return "资源重试失败"
if job.output_mode == "asset_generation":
return "资源生成失败"
return "生成失败"
async def create_generation_job(
db: AsyncSession,
*,
@@ -266,16 +293,64 @@ async def list_story_generation_jobs(
return [generation_job_to_summary(job) for job in jobs]
async def get_active_story_generation_job(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> GenerationJob | None:
"""Return the most recent running job for a story, if any."""
result = await db.execute(
select(GenerationJob)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJob.status == "running",
)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
.limit(1)
)
return result.scalar_one_or_none()
async def ensure_no_active_story_generation_job(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> None:
"""Prevent duplicate asset work while a story already has a running job."""
active_job = await get_active_story_generation_job(db, story_id=story_id, user_id=user_id)
if active_job is None:
return
progress = _job_progress(active_job)
raise HTTPException(
status_code=409,
detail=(
f"当前故事已有运行中的任务({progress['progress_label']}"
"请等待当前任务完成后再试。"
),
)
def _as_float(value: Any) -> float | None:
if isinstance(value, int | float):
return float(value)
return None
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
def _aggregate_provider_events(
events: list[GenerationJobEvent],
*,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry from provider call events."""
by_key: dict[tuple[str, str], dict[str, Any]] = {}
failure_reasons: dict[str, int] = {}
total_latency = 0.0
latency_count = 0
total_cost = 0.0
@@ -284,13 +359,16 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
for event in events:
metadata = event.event_metadata or {}
capability = str(metadata.get("capability") or "unknown")
event_capability = str(metadata.get("capability") or "unknown")
if capability is not None and event_capability != capability:
continue
adapter = str(metadata.get("adapter") or "unknown")
key = (capability, adapter)
key = (event_capability, adapter)
bucket = by_key.setdefault(
key,
{
"capability": capability,
"capability": event_capability,
"adapter": adapter,
"call_count": 0,
"success_count": 0,
@@ -318,6 +396,8 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
else:
bucket["failure_count"] += 1
failed_calls += 1
reason = str(metadata.get("error") or "unknown_error")
failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
by_provider = []
for bucket in by_key.values():
@@ -349,67 +429,243 @@ def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, An
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
"estimated_cost_usd": round(total_cost, 6),
"by_provider": by_provider,
"failure_reasons": [
{"reason": reason, "count": count}
for reason, count in sorted(
failure_reasons.items(),
key=lambda item: (-item[1], item[0]),
)
],
}
def _provider_events_query(
*,
user_id: str,
story_id: int | None = None,
days: int | None = None,
):
query = (
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
)
if story_id is not None:
query = query.where(GenerationJob.story_id == story_id)
if days is not None:
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
query = query.where(GenerationJobEvent.created_at >= cutoff)
return query.order_by(GenerationJobEvent.id)
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
_provider_events_query(
user_id=user_id,
story_id=story_id,
days=days,
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
return {"story_id": story_id, **_aggregate_provider_events(events)}
return {
"story_id": story_id,
"window_days": days,
"capability": capability,
**_aggregate_provider_events(events, capability=capability),
}
async def get_user_provider_analytics(
db: AsyncSession,
*,
user_id: str,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry across all stories owned by one user."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
_provider_events_query(
user_id=user_id,
days=days,
)
)
).scalars().all()
filtered_event_job_ids = {
event.job_id
for event in events
if capability is None
or str((event.event_metadata or {}).get("capability") or "unknown") == capability
}
filtered_story_ids = {
event.story_id
for event in events
if event.story_id is not None
and (
capability is None
or str((event.event_metadata or {}).get("capability") or "unknown") == capability
)
}
return {
"window_days": days,
"capability": capability,
**_aggregate_provider_events(events, capability=capability),
"job_count": len(filtered_event_job_ids),
"story_count": len(filtered_story_ids),
}
async def get_user_generation_ops_summary(
db: AsyncSession,
*,
user_id: str,
hours: int = 24,
recent_failure_limit: int = 5,
) -> dict[str, Any]:
"""Summarize recent generation health for one user."""
stale_after_minutes = settings.generation_job_stale_minutes
recent_cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
running_jobs = (
await db.execute(
select(GenerationJob)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
GenerationJob.status == "running",
)
.order_by(GenerationJobEvent.id)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
)
).scalars().all()
job_count, story_count = (
recent_jobs = (
await db.execute(
select(
func.count(GenerationJob.id),
func.count(distinct(GenerationJob.story_id)),
).where(GenerationJob.user_id == user_id)
select(GenerationJob, Story.title)
.outerjoin(Story, Story.id == GenerationJob.story_id)
.where(
GenerationJob.user_id == user_id,
GenerationJob.updated_at >= recent_cutoff,
)
.order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
)
).one()
).all()
recent_failures: list[dict[str, Any]] = []
failed_jobs = 0
degraded_jobs = 0
asset_retry_jobs = 0
for job, story_title in recent_jobs:
if job.status == "failed":
failed_jobs += 1
if len(recent_failures) < recent_failure_limit:
recent_failures.append(
{
"job_id": job.id,
"story_id": job.story_id,
"story_title": story_title,
"output_mode": job.output_mode,
"current_step": job.current_step,
"error_message": job.error_message,
"failure_label": _failure_label(job),
"updated_at": job.updated_at,
}
)
elif job.status == "degraded_completed":
degraded_jobs += 1
if job.output_mode in {"asset_retry", "asset_generation"}:
asset_retry_jobs += 1
return {
**_aggregate_provider_events(events),
"job_count": job_count,
"story_count": story_count,
"window_hours": hours,
"stale_threshold_minutes": stale_after_minutes,
"active_jobs": len(running_jobs),
"stale_running_jobs": sum(
1 for job in running_jobs if _is_stale_job(job, stale_after_minutes=stale_after_minutes)
),
"failed_jobs": failed_jobs,
"degraded_jobs": degraded_jobs,
"asset_retry_jobs": asset_retry_jobs,
"recent_failures": recent_failures,
}
async def mark_stale_generation_jobs(
db: AsyncSession,
*,
stale_after_minutes: int | None = None,
) -> dict[str, int]:
"""Mark long-running generation jobs as failed so they no longer appear stuck forever."""
threshold = stale_after_minutes or settings.generation_job_stale_minutes
running_jobs = (
await db.execute(
select(GenerationJob)
.where(GenerationJob.status == "running")
.order_by(GenerationJob.updated_at, GenerationJob.id)
)
).scalars().all()
marked_stale = 0
for job in running_jobs:
if not _is_stale_job(job, stale_after_minutes=threshold):
continue
story = None
if job.story_id is not None:
story = (
await db.execute(
select(Story).where(
Story.id == job.story_id,
Story.user_id == job.user_id,
)
)
).scalar_one_or_none()
await finish_generation_job(
db,
job=job,
story=story,
status="failed",
current_step="generation_stale_failed",
error_message=f"Generation job exceeded {threshold} minutes without progress.",
message="Generation job was marked failed after exceeding the stale threshold.",
metadata={"stale_after_minutes": threshold},
)
marked_stale += 1
logger.warning(
"generation_job_marked_stale",
job_id=job.id,
story_id=job.story_id,
output_mode=job.output_mode,
stale_after_minutes=threshold,
)
return {
"running": len(running_jobs),
"marked_stale": marked_stale,
"stale_after_minutes": threshold,
}

View File

@@ -2,6 +2,7 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Literal
from fastapi import HTTPException
@@ -9,6 +10,7 @@ from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import ChildProfile, Story, StoryUniverse
from app.schemas.story_schemas import (
@@ -32,6 +34,7 @@ from app.services.audio_storage import (
)
from app.services.generation_jobs import (
create_generation_job,
ensure_no_active_story_generation_job,
finish_generation_job,
record_generation_event,
)
@@ -1369,6 +1372,7 @@ async def retry_story_assets(
db: AsyncSession,
) -> Story:
"""Retry selected assets through one workflow-level endpoint."""
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
requested_assets = list(dict.fromkeys(assets))
job = await create_generation_job(
db,
@@ -1443,6 +1447,7 @@ async def generate_story_cover(
db: AsyncSession,
) -> str:
"""Generate cover image for an existing story."""
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
job = await create_generation_job(
db,
user_id=user_id,
@@ -1495,6 +1500,7 @@ async def generate_story_audio(
db: AsyncSession,
) -> bytes:
"""Generate audio for a story."""
await ensure_no_active_story_generation_job(db, story_id=story_id, user_id=user_id)
job = await create_generation_job(
db,
user_id=user_id,
@@ -1597,6 +1603,50 @@ async def clear_story_audio_cache(
return await get_story_audio_status(story_id, user_id, db)
async def prune_story_audio_cache(db: AsyncSession) -> dict[str, int]:
"""Prune expired audio cache files and repair story metadata."""
ttl_days = max(1, settings.story_audio_cache_ttl_days)
cutoff = datetime.now(timezone.utc) - timedelta(days=ttl_days)
result = await db.execute(select(Story).where(Story.audio_path.is_not(None)))
stories = result.scalars().all()
scanned = 0
pruned = 0
repaired = 0
for story in stories:
scanned += 1
metadata = get_audio_cache_metadata(story.audio_path)
if not metadata.exists:
story.audio_path = None
if story.audio_status == StoryAssetStatus.READY.value:
sync_story_status(story, audio_status=StoryAssetStatus.NOT_REQUESTED)
repaired += 1
continue
if metadata.updated_at and metadata.updated_at < cutoff:
delete_audio_cache(story.audio_path)
story.audio_path = None
sync_story_status(
story,
audio_status=StoryAssetStatus.NOT_REQUESTED,
last_error=None,
)
pruned += 1
await db.commit()
logger.info(
"story_audio_cache_pruned",
scanned=scanned,
pruned=pruned,
repaired=repaired,
ttl_days=ttl_days,
)
return {"scanned": scanned, "pruned": pruned, "repaired": repaired}
async def get_story_achievements(
story_id: int,
user_id: str,

View File

@@ -0,0 +1,29 @@
"""Celery tasks for story audio cache maintenance."""
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.story_service import prune_story_audio_cache
logger = get_logger(__name__)
@celery_app.task
def prune_story_audio_cache_task():
"""Daily task to prune expired story audio cache files."""
logger.info("prune_story_audio_cache_task_started")
async def _run():
session_factory = _get_session_factory()
async with session_factory() as session:
return await prune_story_audio_cache(session)
try:
result = asyncio.run(_run())
logger.info("prune_story_audio_cache_task_completed", **result)
return result
except Exception as exc:
logger.error("prune_story_audio_cache_task_failed", error=str(exc))
raise

View File

@@ -0,0 +1,30 @@
"""Generation job maintenance tasks."""
import asyncio
from app.core.celery_app import celery_app
from app.core.logging import get_logger
from app.db.database import _get_session_factory
from app.services.generation_jobs import mark_stale_generation_jobs
logger = get_logger(__name__)
@celery_app.task
def prune_stale_generation_jobs_task():
"""Periodically mark stale running generation jobs as failed."""
logger.info("prune_stale_generation_jobs_task_started")
async def _run():
session_factory = _get_session_factory()
async with session_factory() as session:
return await mark_stale_generation_jobs(session)
try:
result = asyncio.run(_run())
logger.info("prune_stale_generation_jobs_task_completed", **result)
return result
except Exception as exc:
logger.error("prune_stale_generation_jobs_task_failed", error=str(exc))
raise