Implement unified story generation flow

This commit is contained in:
2026-06-18 14:48:27 +08:00
parent 0ccfd00a23
commit 7ebdfb2582
27 changed files with 1323 additions and 215 deletions

View File

@@ -0,0 +1,408 @@
"""Admin-facing provider analytics across generation and voice telemetry."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.admin_models import CostRecord
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
from app.services.generation_jobs import (
_aggregate_provider_events,
_as_float,
_event_matches_capability,
_provider_events_query,
)
def _empty_admin_user_bucket(user_id: str) -> dict[str, Any]:
return {
"user_id": user_id,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"estimated_cost_usd": 0.0,
"job_ids": set(),
"story_ids": set(),
}
def _merge_admin_user_bucket(
target: dict[str, Any],
source: dict[str, Any],
) -> None:
target["call_count"] += int(source["call_count"])
target["success_count"] += int(source["success_count"])
target["failure_count"] += int(source["failure_count"])
target["estimated_cost_usd"] += float(source["estimated_cost_usd"])
target["job_ids"].update(source["job_ids"])
target["story_ids"].update(source["story_ids"])
def _serialize_admin_user_buckets(
by_user: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
serialized_users = [
{
"user_id": user_id,
"call_count": bucket["call_count"],
"success_count": bucket["success_count"],
"failure_count": bucket["failure_count"],
"job_count": len(bucket["job_ids"]),
"story_count": len(bucket["story_ids"]),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
for user_id, bucket in by_user.items()
]
serialized_users.sort(
key=lambda item: (
-int(item["call_count"]),
-float(item["estimated_cost_usd"]),
str(item["user_id"]),
)
)
return serialized_users
def _merge_provider_analytics(
left: dict[str, Any],
right: dict[str, Any],
) -> dict[str, Any]:
provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
latency_totals: dict[tuple[str, str], float] = {}
latency_counts: dict[tuple[str, str], int] = {}
failure_reasons: dict[str, int] = {}
for payload in (left, right):
for row in payload["by_provider"]:
capability_name = str(row["capability"])
adapter_name = str(row["adapter"])
key = (capability_name, adapter_name)
bucket = provider_buckets.setdefault(
key,
{
"capability": capability_name,
"adapter": adapter_name,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"estimated_cost_usd": 0.0,
},
)
call_count = int(row["call_count"])
bucket["call_count"] += call_count
bucket["success_count"] += int(row["success_count"])
bucket["failure_count"] += int(row["failure_count"])
bucket["estimated_cost_usd"] += float(row["estimated_cost_usd"])
if row["avg_latency_ms"] is not None and call_count:
latency_totals[key] = latency_totals.get(key, 0.0) + (
float(row["avg_latency_ms"]) * call_count
)
latency_counts[key] = latency_counts.get(key, 0) + call_count
for item in payload["failure_reasons"]:
reason = str(item["reason"])
failure_reasons[reason] = failure_reasons.get(reason, 0) + int(item["count"])
by_provider = []
total_latency = 0.0
latency_count = 0
for key, bucket in provider_buckets.items():
bucket_latency_count = latency_counts.get(key, 0)
bucket_latency_total = latency_totals.get(key, 0.0)
if bucket_latency_count:
total_latency += bucket_latency_total
latency_count += bucket_latency_count
by_provider.append(
{
**bucket,
"avg_latency_ms": (
round(bucket_latency_total / bucket_latency_count, 2)
if bucket_latency_count
else None
),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
)
by_provider.sort(
key=lambda item: (
str(item["capability"]),
str(item["adapter"]),
)
)
return {
"total_calls": int(left["total_calls"]) + int(right["total_calls"]),
"successful_calls": int(left["successful_calls"]) + int(right["successful_calls"]),
"failed_calls": int(left["failed_calls"]) + int(right["failed_calls"]),
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
"estimated_cost_usd": round(
float(left["estimated_cost_usd"]) + float(right["estimated_cost_usd"]),
6,
),
"by_provider": by_provider,
"failure_reasons": [
{"reason": reason, "count": count}
for reason, count in sorted(
failure_reasons.items(),
key=lambda item: (-item[1], item[0]),
)
],
}
def _voice_asr_provider_from_turn(turn: VoiceTurn) -> str:
story_patch = turn.story_patch or {}
return str(story_patch.get("transcription_provider") or "unknown")
async def _aggregate_voice_asr_provider_analytics(
    db: AsyncSession,
    *,
    days: int | None = None,
) -> dict[str, Any]:
    """Summarise ASR usage recorded by voice co-creation sessions.

    A ``VoiceTurn`` with both ``user_audio_path`` and ``user_transcript`` set
    counts as one successful ASR call; every ``VoiceSessionEvent`` of type
    ``"turn_transcription_failed"`` counts as one failed call. Costs come from
    ``CostRecord`` rows whose capability is ``"asr"``.

    Args:
        db: Async session used to run the three queries, one after another.
        days: When given, only rows created (or, for costs, timestamped) at
            or after ``now(UTC) - days`` are considered. ``None`` means no
            time filter.

    Returns:
        A payload with call totals, ``by_provider`` rows (capability is
        always ``"asr"``), sorted ``failure_reasons``, and the raw helpers the
        caller merges: ``by_user`` (unserialised buckets holding sets),
        ``user_ids``, ``story_ids``, ``voice_session_ids`` and
        ``voice_turn_count``. ``avg_latency_ms`` is always ``None`` because
        this function reads no latency data.
    """
    window_start = None
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)

    turns_stmt = (
        select(
            VoiceTurn,
            VoiceSession.user_id,
            VoiceSession.final_story_id,
            VoiceSession.id,
        )
        .join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
        .where(
            VoiceTurn.user_audio_path.isnot(None),
            VoiceTurn.user_transcript.isnot(None),
        )
    )
    failures_stmt = (
        select(
            VoiceSessionEvent,
            VoiceSession.user_id,
            VoiceSession.final_story_id,
            VoiceSession.id,
        )
        .join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id)
        .where(VoiceSessionEvent.event_type == "turn_transcription_failed")
    )
    costs_stmt = select(
        CostRecord.user_id,
        CostRecord.provider_name,
        CostRecord.estimated_cost,
    ).where(CostRecord.capability == "asr")

    if window_start is not None:
        turns_stmt = turns_stmt.where(VoiceTurn.created_at >= window_start)
        failures_stmt = failures_stmt.where(
            VoiceSessionEvent.created_at >= window_start
        )
        costs_stmt = costs_stmt.where(CostRecord.timestamp >= window_start)

    turn_rows = (await db.execute(turns_stmt)).all()
    failure_rows = (await db.execute(failures_stmt)).all()
    cost_rows = (await db.execute(costs_stmt)).all()

    # Pre-sum the cost records per provider and per user.
    provider_costs: dict[str, float] = {}
    user_costs: dict[str, float] = {}
    for cost_user_id, cost_provider, estimated_cost in cost_rows:
        amount = float(estimated_cost or 0.0)
        provider_name = str(cost_provider or "unknown")
        cost_user_key = str(cost_user_id)
        provider_costs[provider_name] = provider_costs.get(provider_name, 0.0) + amount
        user_costs[cost_user_key] = user_costs.get(cost_user_key, 0.0) + amount

    provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
    by_user: dict[str, dict[str, Any]] = {}
    failure_reasons: dict[str, int] = {}
    user_ids: set[str] = set()
    story_ids: set[int] = set()
    voice_session_ids: set[str] = set()

    def record_call(
        adapter: str,
        raw_user_id: Any,
        final_story_id: Any,
        session_id: Any,
        outcome_field: str,
    ) -> None:
        """Count one ASR call for its provider and user under ``outcome_field``."""
        user_key = str(raw_user_id)
        user_ids.add(user_key)
        voice_session_ids.add(str(session_id))

        provider_key = ("asr", adapter)
        if provider_key not in provider_buckets:
            provider_buckets[provider_key] = {
                "capability": "asr",
                "adapter": adapter,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "avg_latency_ms": None,
                "estimated_cost_usd": 0.0,
            }
        if user_key not in by_user:
            by_user[user_key] = _empty_admin_user_bucket(user_key)

        for tally in (provider_buckets[provider_key], by_user[user_key]):
            tally["call_count"] += 1
            tally[outcome_field] += 1

        if final_story_id is not None:
            story_id = int(final_story_id)
            story_ids.add(story_id)
            by_user[user_key]["story_ids"].add(story_id)

    for turn, turn_user_id, turn_story_id, turn_session_id in turn_rows:
        record_call(
            _voice_asr_provider_from_turn(turn),
            turn_user_id,
            turn_story_id,
            turn_session_id,
            "success_count",
        )

    # Costs are attached only to buckets that exist at this point, i.e. to
    # providers and users that have at least one successful turn. Cost
    # records for anything else are left out of the totals.
    # NOTE(review): failure-only providers/users therefore never receive
    # cost — confirm this ordering is intentional.
    for provider_name, amount in provider_costs.items():
        matched_bucket = provider_buckets.get(("asr", provider_name))
        if matched_bucket is not None:
            matched_bucket["estimated_cost_usd"] += amount
    for cost_user_key, amount in user_costs.items():
        matched_user = by_user.get(cost_user_key)
        if matched_user is not None:
            matched_user["estimated_cost_usd"] += amount

    for event, event_user_id, event_story_id, event_session_id in failure_rows:
        metadata = event.event_metadata or {}
        adapter = str(
            metadata.get("adapter")
            or metadata.get("transcription_provider")
            or "unknown"
        )
        reason = str(metadata.get("error") or "unknown_error")
        record_call(
            adapter,
            event_user_id,
            event_story_id,
            event_session_id,
            "failure_count",
        )
        failure_reasons[reason] = failure_reasons.get(reason, 0) + 1

    successful_calls = len(turn_rows)
    failed_calls = len(failure_rows)

    by_provider = sorted(
        (
            {**bucket, "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6)}
            for bucket in provider_buckets.values()
        ),
        key=lambda item: (str(item["capability"]), str(item["adapter"])),
    )
    # The grand total sums the unrounded per-provider costs, then rounds once.
    total_cost = round(
        sum(float(bucket["estimated_cost_usd"]) for bucket in provider_buckets.values()),
        6,
    )
    ranked_reasons = sorted(
        failure_reasons.items(),
        key=lambda pair: (-pair[1], pair[0]),
    )

    return {
        "total_calls": successful_calls + failed_calls,
        "successful_calls": successful_calls,
        "failed_calls": failed_calls,
        "avg_latency_ms": None,
        "estimated_cost_usd": total_cost,
        "by_provider": by_provider,
        "failure_reasons": [
            {"reason": reason, "count": count} for reason, count in ranked_reasons
        ],
        "by_user": by_user,
        "user_ids": user_ids,
        "story_ids": story_ids,
        "voice_session_ids": voice_session_ids,
        "voice_turn_count": successful_calls,
    }
async def get_admin_provider_analytics(
    db: AsyncSession,
    *,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Build the admin provider-analytics payload for all users.

    Provider events returned by ``_provider_events_query`` are aggregated by
    ``_aggregate_provider_events`` and, per user, by this function. When
    ``capability`` is ``None`` or ``"asr"``, voice ASR analytics are merged
    in as well.

    Args:
        db: Async session used for every query.
        days: Optional look-back window forwarded to the event query and the
            voice ASR aggregation.
        capability: Optional capability filter; rows are kept when
            ``_event_matches_capability`` accepts them.

    Returns:
        A dict with ``scope``, ``window_days``, ``capability``, the merged
        provider analytics keys, distinct ``user_count`` / ``job_count`` /
        ``story_count``, voice session/turn counts, and the serialised
        ``by_user`` list.
    """
    rows = (await db.execute(_provider_events_query(days=days))).all()

    events = []
    by_user: dict[str, dict[str, Any]] = {}
    seen_job_ids = set()
    seen_story_ids = set()
    seen_user_ids = set()

    for event, user_id, story_id in rows:
        # The aggregate helper receives every event and filters on its own.
        events.append(event)
        if not _event_matches_capability(event, capability):
            continue

        seen_job_ids.add(event.job_id)
        seen_user_ids.add(user_id)
        # NOTE(review): user IDs are used as-is here, while the voice ASR
        # aggregation keys its buckets by str(user_id). If the ID column is
        # not a string, the same user would land in two buckets below —
        # confirm against the model.
        if user_id not in by_user:
            by_user[user_id] = _empty_admin_user_bucket(user_id)
        bucket = by_user[user_id]

        bucket["call_count"] += 1
        bucket["job_ids"].add(event.job_id)
        if story_id is not None:
            seen_story_ids.add(story_id)
            bucket["story_ids"].add(story_id)

        if event.event_type != "provider_call_succeeded":
            bucket["failure_count"] += 1
            continue
        bucket["success_count"] += 1
        event_metadata = event.event_metadata or {}
        bucket["estimated_cost_usd"] += (
            _as_float(event_metadata.get("estimated_cost_usd")) or 0.0
        )

    provider_analytics = _aggregate_provider_events(events, capability=capability)

    voice_session_count = 0
    voice_turn_count = 0
    if capability in {None, "asr"}:
        asr_analytics = await _aggregate_voice_asr_provider_analytics(db, days=days)
        provider_analytics = _merge_provider_analytics(
            provider_analytics,
            asr_analytics,
        )
        seen_user_ids.update(asr_analytics["user_ids"])
        seen_story_ids.update(asr_analytics["story_ids"])
        voice_session_count = len(asr_analytics["voice_session_ids"])
        voice_turn_count = int(asr_analytics["voice_turn_count"])
        for asr_user_id, asr_bucket in asr_analytics["by_user"].items():
            if asr_user_id not in by_user:
                by_user[asr_user_id] = _empty_admin_user_bucket(asr_user_id)
            _merge_admin_user_bucket(by_user[asr_user_id], asr_bucket)

    return {
        "scope": "current_environment",
        "window_days": days,
        "capability": capability,
        **provider_analytics,
        "user_count": len(seen_user_ids),
        "job_count": len(seen_job_ids),
        "story_count": len(seen_story_ids),
        "voice_session_count": voice_session_count,
        "voice_turn_count": voice_turn_count,
        "by_user": _serialize_admin_user_buckets(by_user),
    }