"""Admin-facing provider analytics across generation and voice telemetry.""" from __future__ import annotations from datetime import datetime, timedelta, timezone from typing import Any from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.db.admin_models import CostRecord from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn from app.services.generation_jobs import ( _aggregate_provider_events, _as_float, _event_matches_capability, _provider_events_query, ) def _empty_admin_user_bucket(user_id: str) -> dict[str, Any]: return { "user_id": user_id, "call_count": 0, "success_count": 0, "failure_count": 0, "estimated_cost_usd": 0.0, "job_ids": set(), "story_ids": set(), } def _merge_admin_user_bucket( target: dict[str, Any], source: dict[str, Any], ) -> None: target["call_count"] += int(source["call_count"]) target["success_count"] += int(source["success_count"]) target["failure_count"] += int(source["failure_count"]) target["estimated_cost_usd"] += float(source["estimated_cost_usd"]) target["job_ids"].update(source["job_ids"]) target["story_ids"].update(source["story_ids"]) def _serialize_admin_user_buckets( by_user: dict[str, dict[str, Any]], ) -> list[dict[str, Any]]: serialized_users = [ { "user_id": user_id, "call_count": bucket["call_count"], "success_count": bucket["success_count"], "failure_count": bucket["failure_count"], "job_count": len(bucket["job_ids"]), "story_count": len(bucket["story_ids"]), "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6), } for user_id, bucket in by_user.items() ] serialized_users.sort( key=lambda item: ( -int(item["call_count"]), -float(item["estimated_cost_usd"]), str(item["user_id"]), ) ) return serialized_users def _merge_provider_analytics( left: dict[str, Any], right: dict[str, Any], ) -> dict[str, Any]: provider_buckets: dict[tuple[str, str], dict[str, Any]] = {} latency_totals: dict[tuple[str, str], float] = {} latency_counts: dict[tuple[str, str], int] = {} failure_reasons: dict[str, int] = {} for payload in (left, right): for row in payload["by_provider"]: capability_name = str(row["capability"]) adapter_name = str(row["adapter"]) key = (capability_name, adapter_name) bucket = provider_buckets.setdefault( key, { "capability": capability_name, "adapter": adapter_name, "call_count": 0, "success_count": 0, "failure_count": 0, "estimated_cost_usd": 0.0, }, ) call_count = int(row["call_count"]) bucket["call_count"] += call_count bucket["success_count"] += int(row["success_count"]) bucket["failure_count"] += int(row["failure_count"]) bucket["estimated_cost_usd"] += float(row["estimated_cost_usd"]) if row["avg_latency_ms"] is not None and call_count: latency_totals[key] = latency_totals.get(key, 0.0) + ( float(row["avg_latency_ms"]) * call_count ) latency_counts[key] = latency_counts.get(key, 0) + call_count for item in payload["failure_reasons"]: reason = str(item["reason"]) failure_reasons[reason] = failure_reasons.get(reason, 0) + int(item["count"]) by_provider = [] total_latency = 0.0 latency_count = 0 for key, bucket in provider_buckets.items(): bucket_latency_count = latency_counts.get(key, 0) bucket_latency_total = latency_totals.get(key, 0.0) if bucket_latency_count: total_latency += bucket_latency_total latency_count += bucket_latency_count by_provider.append( { **bucket, "avg_latency_ms": ( round(bucket_latency_total / bucket_latency_count, 2) if bucket_latency_count else None ), "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6), } ) by_provider.sort( key=lambda item: ( str(item["capability"]), str(item["adapter"]), ) ) return { "total_calls": int(left["total_calls"]) + int(right["total_calls"]), "successful_calls": int(left["successful_calls"]) + int(right["successful_calls"]), "failed_calls": int(left["failed_calls"]) + int(right["failed_calls"]), "avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None, "estimated_cost_usd": round( float(left["estimated_cost_usd"]) + float(right["estimated_cost_usd"]), 6, ), "by_provider": by_provider, "failure_reasons": [ {"reason": reason, "count": count} for reason, count in sorted( failure_reasons.items(), key=lambda item: (-item[1], item[0]), ) ], } def _voice_asr_provider_from_turn(turn: VoiceTurn) -> str: story_patch = turn.story_patch or {} return str(story_patch.get("transcription_provider") or "unknown") async def _aggregate_voice_asr_provider_analytics( db: AsyncSession, *, days: int | None = None, ) -> dict[str, Any]: """Aggregate ASR telemetry from voice co-creation sessions.""" cutoff = datetime.now(timezone.utc) - timedelta(days=days) if days is not None else None turn_query = ( select( VoiceTurn, VoiceSession.user_id, VoiceSession.final_story_id, VoiceSession.id, ) .join(VoiceSession, VoiceTurn.session_id == VoiceSession.id) .where( VoiceTurn.user_audio_path.isnot(None), VoiceTurn.user_transcript.isnot(None), ) ) failure_query = ( select( VoiceSessionEvent, VoiceSession.user_id, VoiceSession.final_story_id, VoiceSession.id, ) .join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id) .where(VoiceSessionEvent.event_type == "turn_transcription_failed") ) cost_query = select( CostRecord.user_id, CostRecord.provider_name, CostRecord.estimated_cost, ).where(CostRecord.capability == "asr") if cutoff is not None: turn_query = turn_query.where(VoiceTurn.created_at >= cutoff) failure_query = failure_query.where(VoiceSessionEvent.created_at >= cutoff) cost_query = cost_query.where(CostRecord.timestamp >= cutoff) turn_rows = (await db.execute(turn_query)).all() failure_rows = (await db.execute(failure_query)).all() cost_rows = (await db.execute(cost_query)).all() costs_by_provider: dict[str, float] = {} costs_by_user: dict[str, float] = {} for user_id, provider_name, estimated_cost in cost_rows: cost = float(estimated_cost or 0.0) provider = str(provider_name or "unknown") costs_by_provider[provider] = costs_by_provider.get(provider, 0.0) + cost costs_by_user[str(user_id)] = costs_by_user.get(str(user_id), 0.0) + cost provider_buckets: dict[tuple[str, str], dict[str, Any]] = {} failure_reasons: dict[str, int] = {} by_user: dict[str, dict[str, Any]] = {} user_ids: set[str] = set() story_ids: set[int] = set() voice_session_ids: set[str] = set() successful_calls = 0 failed_calls = 0 def provider_bucket(adapter: str) -> dict[str, Any]: return provider_buckets.setdefault( ("asr", adapter), { "capability": "asr", "adapter": adapter, "call_count": 0, "success_count": 0, "failure_count": 0, "avg_latency_ms": None, "estimated_cost_usd": 0.0, }, ) for turn, user_id, final_story_id, session_id in turn_rows: user_id = str(user_id) adapter = _voice_asr_provider_from_turn(turn) user_ids.add(user_id) voice_session_ids.add(str(session_id)) if final_story_id is not None: story_ids.add(int(final_story_id)) bucket = provider_bucket(adapter) bucket["call_count"] += 1 bucket["success_count"] += 1 successful_calls += 1 user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id)) user_bucket["call_count"] += 1 user_bucket["success_count"] += 1 if final_story_id is not None: user_bucket["story_ids"].add(int(final_story_id)) for provider_name, cost in costs_by_provider.items(): key = ("asr", provider_name) if key in provider_buckets: provider_buckets[key]["estimated_cost_usd"] += cost for user_id, cost in costs_by_user.items(): if user_id in by_user: by_user[user_id]["estimated_cost_usd"] += cost for event, user_id, final_story_id, session_id in failure_rows: metadata = event.event_metadata or {} adapter = str( metadata.get("adapter") or metadata.get("transcription_provider") or "unknown" ) user_id = str(user_id) reason = str(metadata.get("error") or "unknown_error") user_ids.add(user_id) voice_session_ids.add(str(session_id)) if final_story_id is not None: story_ids.add(int(final_story_id)) bucket = provider_bucket(adapter) bucket["call_count"] += 1 bucket["failure_count"] += 1 failed_calls += 1 failure_reasons[reason] = failure_reasons.get(reason, 0) + 1 user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id)) user_bucket["call_count"] += 1 user_bucket["failure_count"] += 1 if final_story_id is not None: user_bucket["story_ids"].add(int(final_story_id)) by_provider = [ { **bucket, "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6), } for bucket in provider_buckets.values() ] by_provider.sort( key=lambda item: ( str(item["capability"]), str(item["adapter"]), ) ) return { "total_calls": successful_calls + failed_calls, "successful_calls": successful_calls, "failed_calls": failed_calls, "avg_latency_ms": None, "estimated_cost_usd": round( sum(float(bucket["estimated_cost_usd"]) for bucket in provider_buckets.values()), 6, ), "by_provider": by_provider, "failure_reasons": [ {"reason": reason, "count": count} for reason, count in sorted( failure_reasons.items(), key=lambda item: (-item[1], item[0]), ) ], "by_user": by_user, "user_ids": user_ids, "story_ids": story_ids, "voice_session_ids": voice_session_ids, "voice_turn_count": successful_calls, } async def get_admin_provider_analytics( db: AsyncSession, *, days: int | None = None, capability: str | None = None, ) -> dict[str, Any]: """Aggregate provider telemetry across every user in the current environment.""" rows = (await db.execute(_provider_events_query(days=days))).all() events = [event for event, _, _ in rows] filtered_rows = [ (event, user_id, story_id) for event, user_id, story_id in rows if _event_matches_capability(event, capability) ] by_user: dict[str, dict[str, Any]] = {} filtered_job_ids = {event.job_id for event, _, _ in filtered_rows} filtered_story_ids = { story_id for _, _, story_id in filtered_rows if story_id is not None } filtered_user_ids = {user_id for _, user_id, _ in filtered_rows} for event, user_id, story_id in filtered_rows: bucket = by_user.setdefault( user_id, _empty_admin_user_bucket(user_id), ) bucket["call_count"] += 1 bucket["job_ids"].add(event.job_id) if story_id is not None: bucket["story_ids"].add(story_id) if event.event_type == "provider_call_succeeded": bucket["success_count"] += 1 bucket["estimated_cost_usd"] += ( _as_float((event.event_metadata or {}).get("estimated_cost_usd")) or 0.0 ) else: bucket["failure_count"] += 1 provider_analytics = _aggregate_provider_events(events, capability=capability) voice_session_count = 0 voice_turn_count = 0 if capability in {None, "asr"}: asr_analytics = await _aggregate_voice_asr_provider_analytics(db, days=days) provider_analytics = _merge_provider_analytics( provider_analytics, asr_analytics, ) filtered_user_ids.update(asr_analytics["user_ids"]) filtered_story_ids.update(asr_analytics["story_ids"]) voice_session_count = len(asr_analytics["voice_session_ids"]) voice_turn_count = int(asr_analytics["voice_turn_count"]) for user_id, source_bucket in asr_analytics["by_user"].items(): target_bucket = by_user.setdefault( user_id, _empty_admin_user_bucket(user_id), ) _merge_admin_user_bucket(target_bucket, source_bucket) return { "scope": "current_environment", "window_days": days, "capability": capability, **provider_analytics, "user_count": len(filtered_user_ids), "job_count": len(filtered_job_ids), "story_count": len(filtered_story_ids), "voice_session_count": voice_session_count, "voice_turn_count": voice_turn_count, "by_user": _serialize_admin_user_buckets(by_user), }