Implement unified story generation flow
This commit is contained in:
408
backend/app/services/admin_provider_analytics.py
Normal file
408
backend/app/services/admin_provider_analytics.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Admin-facing provider analytics across generation and voice telemetry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.admin_models import CostRecord
|
||||
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
|
||||
from app.services.generation_jobs import (
|
||||
_aggregate_provider_events,
|
||||
_as_float,
|
||||
_event_matches_capability,
|
||||
_provider_events_query,
|
||||
)
|
||||
|
||||
|
||||
def _empty_admin_user_bucket(user_id: str) -> dict[str, Any]:
    """Build a zeroed per-user accumulator for admin analytics.

    Args:
        user_id: Identifier stored under the ``"user_id"`` key.

    Returns:
        A new dict with the call/success/failure counters at ``0``, the
        estimated cost at ``0.0``, and fresh empty ``job_ids`` / ``story_ids``
        sets (a new set object on every call).
    """
    bucket: dict[str, Any] = {"user_id": user_id}
    bucket.update(call_count=0, success_count=0, failure_count=0)
    bucket["estimated_cost_usd"] = 0.0
    bucket["job_ids"] = set()
    bucket["story_ids"] = set()
    return bucket
|
||||
|
||||
|
||||
def _merge_admin_user_bucket(
    target: dict[str, Any],
    source: dict[str, Any],
) -> None:
    """Fold the totals of ``source`` into ``target`` in place.

    Args:
        target: Bucket that is mutated; its counters, cost, and ID sets grow.
        source: Bucket whose values are read (coerced with ``int`` / ``float``)
            and left untouched.
    """
    for counter in ("call_count", "success_count", "failure_count"):
        target[counter] += int(source[counter])

    target["estimated_cost_usd"] += float(source["estimated_cost_usd"])

    for id_field in ("job_ids", "story_ids"):
        target[id_field].update(source[id_field])
|
||||
|
||||
|
||||
def _serialize_admin_user_buckets(
    by_user: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Flatten per-user buckets into a sorted list of plain rows.

    The ID sets are replaced by their sizes (``job_count`` / ``story_count``)
    and the cost is rounded to six decimal places.

    Args:
        by_user: Mapping of user id to its accumulator bucket. The mapping key,
            not the bucket's own ``"user_id"`` entry, becomes the row's id.

    Returns:
        Rows ordered by call count (descending), then rounded cost
        (descending), then the stringified user id (ascending).
    """

    def _ranking(row: dict[str, Any]) -> tuple[int, float, str]:
        return (
            -int(row["call_count"]),
            -float(row["estimated_cost_usd"]),
            str(row["user_id"]),
        )

    return sorted(
        (
            {
                "user_id": uid,
                "call_count": stats["call_count"],
                "success_count": stats["success_count"],
                "failure_count": stats["failure_count"],
                "job_count": len(stats["job_ids"]),
                "story_count": len(stats["story_ids"]),
                "estimated_cost_usd": round(stats["estimated_cost_usd"], 6),
            }
            for uid, stats in by_user.items()
        ),
        key=_ranking,
    )
|
||||
|
||||
|
||||
def _merge_provider_analytics(
    left: dict[str, Any],
    right: dict[str, Any],
) -> dict[str, Any]:
    """Combine two provider-analytics payloads into a single payload.

    Rows in ``by_provider`` are merged per ``(capability, adapter)`` pair and
    ``failure_reasons`` are merged per reason. Latency is recombined as a
    call-count-weighted mean of each row's ``avg_latency_ms``; rows with a
    ``None`` latency or zero calls do not contribute to it.

    Args:
        left: First payload; must carry the total fields, ``by_provider`` and
            ``failure_reasons``.
        right: Second payload with the same shape.

    Returns:
        A new payload with summed totals, merged ``by_provider`` rows sorted
        by capability then adapter, and ``failure_reasons`` sorted by count
        (descending) then reason.
    """
    merged: dict[tuple[str, str], dict[str, Any]] = {}
    weighted_latency: dict[tuple[str, str], float] = {}
    latency_weight: dict[tuple[str, str], int] = {}
    reason_counts: dict[str, int] = {}

    for payload in (left, right):
        for row in payload["by_provider"]:
            capability_name = str(row["capability"])
            adapter_name = str(row["adapter"])
            key = (capability_name, adapter_name)
            if key not in merged:
                merged[key] = {
                    "capability": capability_name,
                    "adapter": adapter_name,
                    "call_count": 0,
                    "success_count": 0,
                    "failure_count": 0,
                    "estimated_cost_usd": 0.0,
                }
            entry = merged[key]

            calls = int(row["call_count"])
            entry["call_count"] += calls
            entry["success_count"] += int(row["success_count"])
            entry["failure_count"] += int(row["failure_count"])
            entry["estimated_cost_usd"] += float(row["estimated_cost_usd"])

            row_latency = row["avg_latency_ms"]
            if row_latency is None or not calls:
                continue
            # Turn the row's average back into a total so rows can be pooled.
            weighted_latency[key] = (
                weighted_latency.get(key, 0.0) + float(row_latency) * calls
            )
            latency_weight[key] = latency_weight.get(key, 0) + calls

        for item in payload["failure_reasons"]:
            reason = str(item["reason"])
            reason_counts[reason] = reason_counts.get(reason, 0) + int(item["count"])

    by_provider = []
    total_latency = 0.0
    latency_count = 0
    for key, entry in merged.items():
        weight = latency_weight.get(key, 0)
        latency_sum = weighted_latency.get(key, 0.0)
        if weight:
            total_latency += latency_sum
            latency_count += weight
            provider_avg = round(latency_sum / weight, 2)
        else:
            provider_avg = None

        provider_row = dict(entry)
        provider_row["avg_latency_ms"] = provider_avg
        provider_row["estimated_cost_usd"] = round(entry["estimated_cost_usd"], 6)
        by_provider.append(provider_row)

    by_provider = sorted(
        by_provider,
        key=lambda item: (str(item["capability"]), str(item["adapter"])),
    )

    overall_avg = round(total_latency / latency_count, 2) if latency_count else None
    combined_cost = round(
        float(left["estimated_cost_usd"]) + float(right["estimated_cost_usd"]),
        6,
    )
    ranked_reasons = sorted(
        reason_counts.items(),
        key=lambda pair: (-pair[1], pair[0]),
    )

    return {
        "total_calls": int(left["total_calls"]) + int(right["total_calls"]),
        "successful_calls": int(left["successful_calls"]) + int(right["successful_calls"]),
        "failed_calls": int(left["failed_calls"]) + int(right["failed_calls"]),
        "avg_latency_ms": overall_avg,
        "estimated_cost_usd": combined_cost,
        "by_provider": by_provider,
        "failure_reasons": [
            {"reason": reason, "count": count} for reason, count in ranked_reasons
        ],
    }
|
||||
|
||||
|
||||
def _voice_asr_provider_from_turn(turn: VoiceTurn) -> str:
    """Read the ASR provider name recorded on a voice turn.

    Looks up ``"transcription_provider"`` in ``turn.story_patch`` and returns
    it as a string; falls back to ``"unknown"`` when the patch is empty/None
    or the entry is missing or falsy.
    """
    patch = turn.story_patch
    provider = patch.get("transcription_provider") if patch else None
    return str(provider) if provider else "unknown"
|
||||
|
||||
|
||||
async def _aggregate_voice_asr_provider_analytics(
    db: AsyncSession,
    *,
    days: int | None = None,
) -> dict[str, Any]:
    """Summarize ASR usage recorded by voice sessions.

    Successful calls are voice turns that have both an audio path and a
    transcript; failed calls are ``turn_transcription_failed`` session events.
    ``CostRecord`` rows whose capability is ``"asr"`` supply estimated costs.

    Args:
        db: Async session used to run the three read queries.
        days: When given, only rows timestamped within the last ``days`` days
            are considered; ``None`` applies no time filter.

    Returns:
        A provider-analytics payload (totals, ``by_provider``,
        ``failure_reasons``) extended with ``by_user`` buckets, the sets
        ``user_ids`` / ``story_ids`` / ``voice_session_ids``, and
        ``voice_turn_count``. ``avg_latency_ms`` is always ``None`` here.
    """
    turn_filters = [
        VoiceTurn.user_audio_path.isnot(None),
        VoiceTurn.user_transcript.isnot(None),
    ]
    failure_filters = [VoiceSessionEvent.event_type == "turn_transcription_failed"]
    cost_filters = [CostRecord.capability == "asr"]
    if days is not None:
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        turn_filters.append(VoiceTurn.created_at >= cutoff)
        failure_filters.append(VoiceSessionEvent.created_at >= cutoff)
        cost_filters.append(CostRecord.timestamp >= cutoff)

    session_columns = (
        VoiceSession.user_id,
        VoiceSession.final_story_id,
        VoiceSession.id,
    )
    turn_stmt = (
        select(VoiceTurn, *session_columns)
        .join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
        .where(*turn_filters)
    )
    failure_stmt = (
        select(VoiceSessionEvent, *session_columns)
        .join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id)
        .where(*failure_filters)
    )
    cost_stmt = select(
        CostRecord.user_id,
        CostRecord.provider_name,
        CostRecord.estimated_cost,
    ).where(*cost_filters)

    turn_rows = (await db.execute(turn_stmt)).all()
    failure_rows = (await db.execute(failure_stmt)).all()
    cost_rows = (await db.execute(cost_stmt)).all()

    # Pre-sum recorded ASR costs per provider and per user.
    provider_costs: dict[str, float] = {}
    user_costs: dict[str, float] = {}
    for cost_user_id, provider_name, estimated_cost in cost_rows:
        amount = float(estimated_cost or 0.0)
        provider_key = str(provider_name or "unknown")
        user_key = str(cost_user_id)
        provider_costs[provider_key] = provider_costs.get(provider_key, 0.0) + amount
        user_costs[user_key] = user_costs.get(user_key, 0.0) + amount

    provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
    by_user: dict[str, dict[str, Any]] = {}
    failure_reasons: dict[str, int] = {}
    user_ids: set[str] = set()
    story_ids: set[int] = set()
    voice_session_ids: set[str] = set()

    def provider_bucket(adapter: str) -> dict[str, Any]:
        """Return the ``("asr", adapter)`` bucket, creating it on first use."""
        key = ("asr", adapter)
        if key not in provider_buckets:
            provider_buckets[key] = {
                "capability": "asr",
                "adapter": adapter,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "avg_latency_ms": None,
                "estimated_cost_usd": 0.0,
            }
        return provider_buckets[key]

    def tally(
        adapter: str,
        user_key: str,
        final_story_id: Any,
        session_id: Any,
        outcome_field: str,
    ) -> None:
        """Record one call against the provider, the user, and the id sets."""
        user_ids.add(user_key)
        voice_session_ids.add(str(session_id))

        provider_stats = provider_bucket(adapter)
        if user_key not in by_user:
            by_user[user_key] = _empty_admin_user_bucket(user_key)
        user_stats = by_user[user_key]
        for stats in (provider_stats, user_stats):
            stats["call_count"] += 1
            stats[outcome_field] += 1

        if final_story_id is not None:
            story_key = int(final_story_id)
            story_ids.add(story_key)
            user_stats["story_ids"].add(story_key)

    for turn, raw_user_id, final_story_id, session_id in turn_rows:
        user_key = str(raw_user_id)
        adapter = _voice_asr_provider_from_turn(turn)
        tally(adapter, user_key, final_story_id, session_id, "success_count")

    # Costs are attributed at this point, before failure events are tallied:
    # only providers and users that already have a bucket from a successful
    # turn receive their recorded cost.
    for provider_name, amount in provider_costs.items():
        provider_stats = provider_buckets.get(("asr", provider_name))
        if provider_stats is not None:
            provider_stats["estimated_cost_usd"] += amount

    for user_key, amount in user_costs.items():
        user_stats = by_user.get(user_key)
        if user_stats is not None:
            user_stats["estimated_cost_usd"] += amount

    for event, raw_user_id, final_story_id, session_id in failure_rows:
        metadata = event.event_metadata or {}
        adapter = str(
            metadata.get("adapter")
            or metadata.get("transcription_provider")
            or "unknown"
        )
        reason = str(metadata.get("error") or "unknown_error")
        failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
        tally(adapter, str(raw_user_id), final_story_id, session_id, "failure_count")

    successful_calls = len(turn_rows)
    failed_calls = len(failure_rows)

    by_provider = sorted(
        (
            {**stats, "estimated_cost_usd": round(stats["estimated_cost_usd"], 6)}
            for stats in provider_buckets.values()
        ),
        key=lambda item: (str(item["capability"]), str(item["adapter"])),
    )
    total_cost = round(
        sum(float(stats["estimated_cost_usd"]) for stats in provider_buckets.values()),
        6,
    )
    ranked_reasons = sorted(
        failure_reasons.items(),
        key=lambda pair: (-pair[1], pair[0]),
    )

    return {
        "total_calls": successful_calls + failed_calls,
        "successful_calls": successful_calls,
        "failed_calls": failed_calls,
        "avg_latency_ms": None,
        "estimated_cost_usd": total_cost,
        "by_provider": by_provider,
        "failure_reasons": [
            {"reason": reason, "count": count} for reason, count in ranked_reasons
        ],
        "by_user": by_user,
        "user_ids": user_ids,
        "story_ids": story_ids,
        "voice_session_ids": voice_session_ids,
        "voice_turn_count": successful_calls,
    }
|
||||
|
||||
|
||||
async def get_admin_provider_analytics(
    db: AsyncSession,
    *,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Build the admin provider-analytics report across all users.

    Provider events are loaded via ``_provider_events_query`` and summarized
    with ``_aggregate_provider_events``; when ``capability`` is ``None`` or
    ``"asr"`` the voice-session ASR analytics are merged in as well.

    Args:
        db: Async session used for every query.
        days: Optional look-back window, forwarded to the underlying queries
            and echoed back as ``window_days``.
        capability: Optional capability filter applied to provider events.

    Returns:
        The merged provider analytics plus ``scope``, ``window_days``,
        ``capability``, distinct user/job/story counts, voice session and
        turn counts, and the serialized per-user breakdown under ``by_user``.
    """
    rows = (await db.execute(_provider_events_query(days=days))).all()

    events = []
    by_user: dict[str, dict[str, Any]] = {}
    filtered_job_ids = set()
    filtered_story_ids = set()
    filtered_user_ids = set()

    for event, user_id, story_id in rows:
        events.append(event)
        if not _event_matches_capability(event, capability):
            continue

        filtered_job_ids.add(event.job_id)
        filtered_user_ids.add(user_id)
        if story_id is not None:
            filtered_story_ids.add(story_id)

        # NOTE(review): keyed by the raw user_id here, while the ASR side keys
        # by str(user_id); presumably user ids are already strings — confirm.
        if user_id not in by_user:
            by_user[user_id] = _empty_admin_user_bucket(user_id)
        user_stats = by_user[user_id]
        user_stats["call_count"] += 1
        user_stats["job_ids"].add(event.job_id)
        if story_id is not None:
            user_stats["story_ids"].add(story_id)

        if event.event_type != "provider_call_succeeded":
            user_stats["failure_count"] += 1
            continue
        user_stats["success_count"] += 1
        event_cost = _as_float((event.event_metadata or {}).get("estimated_cost_usd"))
        user_stats["estimated_cost_usd"] += event_cost or 0.0

    provider_analytics = _aggregate_provider_events(events, capability=capability)

    voice_session_count = 0
    voice_turn_count = 0
    if capability is None or capability == "asr":
        asr_analytics = await _aggregate_voice_asr_provider_analytics(db, days=days)
        provider_analytics = _merge_provider_analytics(
            provider_analytics,
            asr_analytics,
        )
        filtered_user_ids.update(asr_analytics["user_ids"])
        filtered_story_ids.update(asr_analytics["story_ids"])
        voice_session_count = len(asr_analytics["voice_session_ids"])
        voice_turn_count = int(asr_analytics["voice_turn_count"])

        for asr_user_id, asr_stats in asr_analytics["by_user"].items():
            if asr_user_id not in by_user:
                by_user[asr_user_id] = _empty_admin_user_bucket(asr_user_id)
            _merge_admin_user_bucket(by_user[asr_user_id], asr_stats)

    report: dict[str, Any] = {
        "scope": "current_environment",
        "window_days": days,
        "capability": capability,
        **provider_analytics,
    }
    report["user_count"] = len(filtered_user_ids)
    report["job_count"] = len(filtered_job_ids)
    report["story_count"] = len(filtered_story_ids)
    report["voice_session_count"] = voice_session_count
    report["voice_turn_count"] = voice_turn_count
    report["by_user"] = _serialize_admin_user_buckets(by_user)
    return report
|
||||
Reference in New Issue
Block a user