feat: add provider analytics summary
This commit is contained in:
@@ -19,6 +19,7 @@ from app.schemas.story_schemas import (
|
||||
GenerateRequest,
|
||||
GenerationJobDetailResponse,
|
||||
GenerationJobSummaryResponse,
|
||||
GenerationProviderAnalyticsResponse,
|
||||
GenerationProviderStatsResponse,
|
||||
GenerationRequest,
|
||||
GenerationResponse,
|
||||
@@ -34,6 +35,7 @@ from app.services import story_service
|
||||
from app.services.generation_jobs import (
|
||||
get_generation_job_detail,
|
||||
get_story_provider_stats,
|
||||
get_user_provider_analytics,
|
||||
list_story_generation_jobs,
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
@@ -83,6 +85,18 @@ async def get_generation_job(
|
||||
return await get_generation_job_detail(db, job_id=job_id, user_id=user.id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/generations/provider-analytics",
|
||||
response_model=GenerationProviderAnalyticsResponse,
|
||||
)
|
||||
async def get_generation_provider_analytics(
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get provider call stats aggregated across the user's generation history."""
|
||||
return await get_user_provider_analytics(db, user_id=user.id)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/generations/{story_id}/jobs",
|
||||
response_model=list[GenerationJobSummaryResponse],
|
||||
|
||||
@@ -222,6 +222,19 @@ class GenerationProviderStatsResponse(BaseModel):
|
||||
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class GenerationProviderAnalyticsResponse(BaseModel):
|
||||
"""Provider call stats aggregated across one user's generation history."""
|
||||
|
||||
total_calls: int
|
||||
successful_calls: int
|
||||
failed_calls: int
|
||||
avg_latency_ms: float | None = None
|
||||
estimated_cost_usd: float = 0.0
|
||||
job_count: int
|
||||
story_count: int
|
||||
by_provider: list[GenerationProviderStatResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
"""Achievement item returned for a story."""
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy import desc, distinct, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
@@ -272,28 +272,8 @@ def _as_float(value: Any) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
async def get_story_provider_stats(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
|
||||
|
||||
events = (
|
||||
await db.execute(
|
||||
select(GenerationJobEvent)
|
||||
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
|
||||
.where(
|
||||
GenerationJob.story_id == story_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
GenerationJobEvent.event_type.in_(
|
||||
["provider_call_succeeded", "provider_call_failed"]
|
||||
),
|
||||
)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
def _aggregate_provider_events(events: list[GenerationJobEvent]) -> dict[str, Any]:
|
||||
"""Aggregate provider telemetry from provider call events."""
|
||||
|
||||
by_key: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
total_latency = 0.0
|
||||
@@ -363,7 +343,6 @@ async def get_story_provider_stats(
|
||||
)
|
||||
|
||||
return {
|
||||
"story_id": story_id,
|
||||
"total_calls": successful_calls + failed_calls,
|
||||
"successful_calls": successful_calls,
|
||||
"failed_calls": failed_calls,
|
||||
@@ -371,3 +350,66 @@ async def get_story_provider_stats(
|
||||
"estimated_cost_usd": round(total_cost, 6),
|
||||
"by_provider": by_provider,
|
||||
}
|
||||
|
||||
|
||||
async def get_story_provider_stats(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider call telemetry from all user-owned jobs for one story."""
|
||||
|
||||
events = (
|
||||
await db.execute(
|
||||
select(GenerationJobEvent)
|
||||
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
|
||||
.where(
|
||||
GenerationJob.story_id == story_id,
|
||||
GenerationJob.user_id == user_id,
|
||||
GenerationJobEvent.event_type.in_(
|
||||
["provider_call_succeeded", "provider_call_failed"]
|
||||
),
|
||||
)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
return {"story_id": story_id, **_aggregate_provider_events(events)}
|
||||
|
||||
|
||||
async def get_user_provider_analytics(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider telemetry across all stories owned by one user."""
|
||||
|
||||
events = (
|
||||
await db.execute(
|
||||
select(GenerationJobEvent)
|
||||
.join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
|
||||
.where(
|
||||
GenerationJob.user_id == user_id,
|
||||
GenerationJobEvent.event_type.in_(
|
||||
["provider_call_succeeded", "provider_call_failed"]
|
||||
),
|
||||
)
|
||||
.order_by(GenerationJobEvent.id)
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
job_count, story_count = (
|
||||
await db.execute(
|
||||
select(
|
||||
func.count(GenerationJob.id),
|
||||
func.count(distinct(GenerationJob.story_id)),
|
||||
).where(GenerationJob.user_id == user_id)
|
||||
)
|
||||
).one()
|
||||
|
||||
return {
|
||||
**_aggregate_provider_events(events),
|
||||
"job_count": job_count,
|
||||
"story_count": story_count,
|
||||
}
|
||||
|
||||
@@ -431,3 +431,123 @@ async def test_story_provider_stats_aggregate_job_events(
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_user_provider_analytics_aggregate_across_stories(
|
||||
db_session,
|
||||
auth_token,
|
||||
degraded_story_with_text,
|
||||
test_story,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
image_job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=degraded_story_with_text.user_id,
|
||||
output_mode="asset_retry",
|
||||
input_type="image",
|
||||
request_payload={"assets": ["image"]},
|
||||
story_id=degraded_story_with_text.id,
|
||||
)
|
||||
await record_generation_event(
|
||||
db_session,
|
||||
job=image_job,
|
||||
story_id=degraded_story_with_text.id,
|
||||
event_type="provider_call_succeeded",
|
||||
status="succeeded",
|
||||
metadata={
|
||||
"capability": "image",
|
||||
"adapter": "demo",
|
||||
"strategy": "priority",
|
||||
"latency_ms": 42,
|
||||
"estimated_cost_usd": 0.01,
|
||||
},
|
||||
)
|
||||
await record_generation_event(
|
||||
db_session,
|
||||
job=image_job,
|
||||
story_id=degraded_story_with_text.id,
|
||||
event_type="provider_call_failed",
|
||||
status="failed",
|
||||
metadata={
|
||||
"capability": "image",
|
||||
"adapter": "cqtai",
|
||||
"strategy": "priority",
|
||||
"latency_ms": 120,
|
||||
"error": "timeout",
|
||||
},
|
||||
)
|
||||
|
||||
audio_job = await create_generation_job(
|
||||
db_session,
|
||||
user_id=test_story.user_id,
|
||||
output_mode="asset_retry",
|
||||
input_type="audio",
|
||||
request_payload={"assets": ["audio"]},
|
||||
story_id=test_story.id,
|
||||
)
|
||||
await record_generation_event(
|
||||
db_session,
|
||||
job=audio_job,
|
||||
story_id=test_story.id,
|
||||
event_type="provider_call_succeeded",
|
||||
status="succeeded",
|
||||
metadata={
|
||||
"capability": "tts",
|
||||
"adapter": "edge_tts",
|
||||
"strategy": "priority",
|
||||
"latency_ms": 18,
|
||||
"estimated_cost_usd": 0.003,
|
||||
},
|
||||
)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.get("/api/generations/provider-analytics")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["job_count"] == 2
|
||||
assert data["story_count"] == 2
|
||||
assert data["total_calls"] == 3
|
||||
assert data["successful_calls"] == 2
|
||||
assert data["failed_calls"] == 1
|
||||
assert data["avg_latency_ms"] == 60.0
|
||||
assert data["estimated_cost_usd"] == 0.013
|
||||
assert data["by_provider"] == [
|
||||
{
|
||||
"capability": "image",
|
||||
"adapter": "cqtai",
|
||||
"call_count": 1,
|
||||
"success_count": 0,
|
||||
"failure_count": 1,
|
||||
"avg_latency_ms": 120.0,
|
||||
"estimated_cost_usd": 0.0,
|
||||
},
|
||||
{
|
||||
"capability": "image",
|
||||
"adapter": "demo",
|
||||
"call_count": 1,
|
||||
"success_count": 1,
|
||||
"failure_count": 0,
|
||||
"avg_latency_ms": 42.0,
|
||||
"estimated_cost_usd": 0.01,
|
||||
},
|
||||
{
|
||||
"capability": "tts",
|
||||
"adapter": "edge_tts",
|
||||
"call_count": 1,
|
||||
"success_count": 1,
|
||||
"failure_count": 0,
|
||||
"avg_latency_ms": 18.0,
|
||||
"estimated_cost_usd": 0.003,
|
||||
},
|
||||
]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
Reference in New Issue
Block a user