feat: add provider analytics summary

This commit is contained in:
2026-04-18 22:01:34 +08:00
parent e99a7fbe14
commit 4d54c144a8
15 changed files with 437 additions and 36 deletions

View File

@@ -19,6 +19,7 @@ from app.schemas.story_schemas import (
GenerateRequest,
GenerationJobDetailResponse,
GenerationJobSummaryResponse,
GenerationProviderAnalyticsResponse,
GenerationProviderStatsResponse,
GenerationRequest,
GenerationResponse,
@@ -34,6 +35,7 @@ from app.services import story_service
from app.services.generation_jobs import (
get_generation_job_detail,
get_story_provider_stats,
get_user_provider_analytics,
list_story_generation_jobs,
)
from app.services.memory_service import build_enhanced_memory_context
@@ -83,6 +85,18 @@ async def get_generation_job(
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
@router.get(
"/generations/provider-analytics",
response_model=GenerationProviderAnalyticsResponse,
)
async def get_generation_provider_analytics(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get provider call stats aggregated across the user's generation history."""
return await get_user_provider_analytics(db, user_id=user.id)
@router.get(
"/generations/{story_id}/jobs",
response_model=list[GenerationJobSummaryResponse],

View File

@@ -222,6 +222,19 @@ class GenerationProviderStatsResponse(BaseModel):
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
class GenerationProviderAnalyticsResponse(BaseModel):
"""Provider call stats aggregated across one user's generation history."""
total_calls: int
successful_calls: int
failed_calls: int
avg_latency_ms: float | None = None
estimated_cost_usd: float = 0.0
job_count: int
story_count: int
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
class AchievementItem(BaseModel):
"""Achievement item returned for a story."""

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy import desc, distinct, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import GenerationJob, GenerationJobEvent, Story
@@ -272,28 +272,8 @@ def _as_float(value: Any) -> float | None:
return None
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
"""Aggregate provider telemetry from provider call events."""
by_key: dict[tuple[str, str], dict[str, Any]] = {}
total_latency = 0.0
@@ -363,7 +343,6 @@ async def get_story_provider_stats(
)
return {
"story_id": story_id,
"total_calls": successful_calls + failed_calls,
"successful_calls": successful_calls,
"failed_calls": failed_calls,
@@ -371,3 +350,66 @@ async def get_story_provider_stats(
"estimated_cost_usd": round(total_cost, 6),
"by_provider": by_provider,
}
async def get_story_provider_stats(
db: AsyncSession,
*,
story_id: int,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.story_id == story_id,
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
return {"story_id": story_id, **_aggregate_provider_events(events)}
async def get_user_provider_analytics(
db: AsyncSession,
*,
user_id: str,
) -> dict[str, Any]:
"""Aggregate provider telemetry across all stories owned by one user."""
events = (
await db.execute(
select(GenerationJobEvent)
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
.where(
GenerationJob.user_id == user_id,
GenerationJobEvent.event_type.in_(
["provider_call_succeeded", "provider_call_failed"]
),
)
.order_by(GenerationJobEvent.id)
)
).scalars().all()
job_count, story_count = (
await db.execute(
select(
func.count(GenerationJob.id),
func.count(distinct(GenerationJob.story_id)),
).where(GenerationJob.user_id == user_id)
)
).one()
return {
**_aggregate_provider_events(events),
"job_count": job_count,
"story_count": story_count,
}

View File

@@ -431,3 +431,123 @@ async def test_story_provider_stats_aggregate_job_events(
]
finally:
app.dependency_overrides.clear()
async def test_user_provider_analytics_aggregate_across_stories(
db_session,
auth_token,
degraded_story_with_text,
test_story,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
image_job = await create_generation_job(
db_session,
user_id=degraded_story_with_text.user_id,
output_mode="asset_retry",
input_type="image",
request_payload={"assets": ["image"]},
story_id=degraded_story_with_text.id,
)
await record_generation_event(
db_session,
job=image_job,
story_id=degraded_story_with_text.id,
event_type="provider_call_succeeded",
status="succeeded",
metadata={
"capability": "image",
"adapter": "demo",
"strategy": "priority",
"latency_ms": 42,
"estimated_cost_usd": 0.01,
},
)
await record_generation_event(
db_session,
job=image_job,
story_id=degraded_story_with_text.id,
event_type="provider_call_failed",
status="failed",
metadata={
"capability": "image",
"adapter": "cqtai",
"strategy": "priority",
"latency_ms": 120,
"error": "timeout",
},
)
audio_job = await create_generation_job(
db_session,
user_id=test_story.user_id,
output_mode="asset_retry",
input_type="audio",
request_payload={"assets": ["audio"]},
story_id=test_story.id,
)
await record_generation_event(
db_session,
job=audio_job,
story_id=test_story.id,
event_type="provider_call_succeeded",
status="succeeded",
metadata={
"capability": "tts",
"adapter": "edge_tts",
"strategy": "priority",
"latency_ms": 18,
"estimated_cost_usd": 0.003,
},
)
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.get("/api/generations/provider-analytics")
assert response.status_code == 200
data = response.json()
assert data["job_count"] == 2
assert data["story_count"] == 2
assert data["total_calls"] == 3
assert data["successful_calls"] == 2
assert data["failed_calls"] == 1
assert data["avg_latency_ms"] == 60.0
assert data["estimated_cost_usd"] == 0.013
assert data["by_provider"] == [
{
"capability": "image",
"adapter": "cqtai",
"call_count": 1,
"success_count": 0,
"failure_count": 1,
"avg_latency_ms": 120.0,
"estimated_cost_usd": 0.0,
},
{
"capability": "image",
"adapter": "demo",
"call_count": 1,
"success_count": 1,
"failure_count": 0,
"avg_latency_ms": 42.0,
"estimated_cost_usd": 0.01,
},
{
"capability": "tts",
"adapter": "edge_tts",
"call_count": 1,
"success_count": 1,
"failure_count": 0,
"avg_latency_ms": 18.0,
"estimated_cost_usd": 0.003,
},
]
finally:
app.dependency_overrides.clear()