feat: add provider analytics summary

This commit is contained in:
2026-04-18 22:01:34 +08:00
parent e99a7fbe14
commit 4d54c144a8
15 changed files with 437 additions and 36 deletions

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy import desc, distinct, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import GenerationJob, GenerationJobEvent, Story
@@ -272,28 +272,8 @@ def _as_float(value: Any) -> float | None:
return None
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
"""Aggregate provider telemetry from provider call events."""
by_key: dict[tuple[str, str], dict[str, Any]] = {}
total_latency = 0.0
@@ -363,7 +343,6 @@ async def get_story_provider_stats(
)
return {
"story_id": story_id,
"total_calls": successful_calls + failed_calls,
"successful_calls": successful_calls,
"failed_calls": failed_calls,
@@ -371,3 +350,66 @@ async def get_story_provider_stats(
"estimated_cost_usd": round(total_cost, 6),
"by_provider": by_provider,
}
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
return {"story_id": story_id, **_aggregate_provider_events(events)}
async def get_user_provider_analytics(
db: AsyncSession,
*,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider telemetry across all stories owned by one user."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
job_count, story_count = (
await db.execute(
select(
func.count(GenerationJob.id),
func.count(distinct(GenerationJob.story_id)),
).where(GenerationJob.user_id == user_id)
)
).one()
return {
**_aggregate_provider_events(events),
"job_count": job_count,
"story_count": story_count,
}