Implement unified story generation flow
This commit is contained in:
@@ -9,8 +9,8 @@ from app.core.admin_auth import admin_guard
|
||||
from app.db.admin_models import Provider
|
||||
from app.db.database import get_db
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.admin_provider_analytics import get_admin_provider_analytics
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.generation_jobs import get_admin_provider_analytics
|
||||
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
|
||||
from app.services.secret_service import SecretService
|
||||
|
||||
@@ -97,6 +97,8 @@ class ProviderAnalyticsResponse(BaseModel):
|
||||
user_count: int
|
||||
job_count: int
|
||||
story_count: int
|
||||
voice_session_count: int = 0
|
||||
voice_turn_count: int = 0
|
||||
by_provider: list[ProviderAnalyticsBucket]
|
||||
by_user: list[ProviderAnalyticsUserBucket]
|
||||
failure_reasons: list[ProviderAnalyticsFailureReason]
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
BACKEND_DIR = Path(__file__).resolve().parents[2]
|
||||
BACKEND_ENV_FILE = BACKEND_DIR / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用全局配置"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=BACKEND_ENV_FILE,
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# 应用基础配置
|
||||
app_name: str = "DreamWeaver"
|
||||
@@ -34,9 +39,10 @@ class Settings(BaseSettings):
|
||||
tts_api_key: str = ""
|
||||
image_api_key: str = ""
|
||||
|
||||
# Additional Provider API Keys
|
||||
openai_api_key: str = ""
|
||||
elevenlabs_api_key: str = ""
|
||||
# Additional Provider API Keys
|
||||
openai_api_key: str = ""
|
||||
openai_api_base: str = ""
|
||||
elevenlabs_api_key: str = ""
|
||||
cqtai_api_key: str = ""
|
||||
minimax_api_key: str = ""
|
||||
minimax_group_id: str = ""
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.services.adapters.asr import openai as _asr_openai_adapter # noqa: F40
|
||||
from app.services.adapters.base import AdapterConfig, BaseAdapter
|
||||
|
||||
# Image adapters
|
||||
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
|
||||
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from openai import APIConnectionError, APIStatusError, APITimeoutError, AsyncOpenAI
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters.asr.models import TranscriptionOutput
|
||||
@@ -15,6 +16,14 @@ from app.services.adapters.registry import AdapterRegistry
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _mask_openai_error(message: str) -> str:
|
||||
"""Avoid leaking bearer tokens while keeping ASR smoke failures actionable."""
|
||||
|
||||
sanitized = message.replace("\n", " ").strip()
|
||||
sanitized = re.sub(r"Bearer\s+[A-Za-z0-9._-]+", "Bearer ***", sanitized)
|
||||
return re.sub(r"sk-[A-Za-z0-9_-]+", "sk-***", sanitized)
|
||||
|
||||
|
||||
@AdapterRegistry.register("asr", "openai_asr")
|
||||
class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
|
||||
"""Transcribe uploaded voice turn audio with OpenAI audio transcription."""
|
||||
@@ -37,7 +46,11 @@ class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
|
||||
detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
|
||||
)
|
||||
|
||||
client = AsyncOpenAI(api_key=self.config.api_key)
|
||||
client = AsyncOpenAI(
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.api_base or None,
|
||||
timeout=self.config.timeout_ms / 1000,
|
||||
)
|
||||
audio_file = BytesIO(audio_bytes)
|
||||
audio_file.name = file_name or "voice-turn.webm"
|
||||
|
||||
@@ -51,11 +64,29 @@ class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
)
|
||||
except APIStatusError as exc:
|
||||
detail = _mask_openai_error(getattr(exc, "message", str(exc)))
|
||||
logger.warning(
|
||||
"openai_asr_failed",
|
||||
status_code=exc.status_code,
|
||||
error=detail,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"OpenAI ASR 调用失败(HTTP {exc.status_code}):{detail}",
|
||||
) from exc
|
||||
except (APITimeoutError, APIConnectionError) as exc:
|
||||
detail = _mask_openai_error(str(exc))
|
||||
logger.warning("openai_asr_failed", error=detail)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"OpenAI ASR 网络连接失败:{detail}",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.warning("openai_asr_failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="语音转写服务暂时不可用,请稍后重试。",
|
||||
detail=f"OpenAI ASR 调用异常:{_mask_openai_error(str(exc))}",
|
||||
) from exc
|
||||
|
||||
transcript_text = (getattr(response, "text", "") or "").strip()
|
||||
|
||||
@@ -126,6 +126,11 @@ class MiniMaxTTSAdapter(BaseAdapter[bytes]):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def estimated_cost(self) -> float:
|
||||
"""预估每次短文本语音合成成本 (USD)。"""
|
||||
return 0.01
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
|
||||
408
backend/app/services/admin_provider_analytics.py
Normal file
408
backend/app/services/admin_provider_analytics.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Admin-facing provider analytics across generation and voice telemetry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.admin_models import CostRecord
|
||||
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
|
||||
from app.services.generation_jobs import (
|
||||
_aggregate_provider_events,
|
||||
_as_float,
|
||||
_event_matches_capability,
|
||||
_provider_events_query,
|
||||
)
|
||||
|
||||
|
||||
def _empty_admin_user_bucket(user_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"estimated_cost_usd": 0.0,
|
||||
"job_ids": set(),
|
||||
"story_ids": set(),
|
||||
}
|
||||
|
||||
|
||||
def _merge_admin_user_bucket(
|
||||
target: dict[str, Any],
|
||||
source: dict[str, Any],
|
||||
) -> None:
|
||||
target["call_count"] += int(source["call_count"])
|
||||
target["success_count"] += int(source["success_count"])
|
||||
target["failure_count"] += int(source["failure_count"])
|
||||
target["estimated_cost_usd"] += float(source["estimated_cost_usd"])
|
||||
target["job_ids"].update(source["job_ids"])
|
||||
target["story_ids"].update(source["story_ids"])
|
||||
|
||||
|
||||
def _serialize_admin_user_buckets(
|
||||
by_user: dict[str, dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
serialized_users = [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"call_count": bucket["call_count"],
|
||||
"success_count": bucket["success_count"],
|
||||
"failure_count": bucket["failure_count"],
|
||||
"job_count": len(bucket["job_ids"]),
|
||||
"story_count": len(bucket["story_ids"]),
|
||||
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
|
||||
}
|
||||
for user_id, bucket in by_user.items()
|
||||
]
|
||||
serialized_users.sort(
|
||||
key=lambda item: (
|
||||
-int(item["call_count"]),
|
||||
-float(item["estimated_cost_usd"]),
|
||||
str(item["user_id"]),
|
||||
)
|
||||
)
|
||||
return serialized_users
|
||||
|
||||
|
||||
def _merge_provider_analytics(
|
||||
left: dict[str, Any],
|
||||
right: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
latency_totals: dict[tuple[str, str], float] = {}
|
||||
latency_counts: dict[tuple[str, str], int] = {}
|
||||
failure_reasons: dict[str, int] = {}
|
||||
|
||||
for payload in (left, right):
|
||||
for row in payload["by_provider"]:
|
||||
capability_name = str(row["capability"])
|
||||
adapter_name = str(row["adapter"])
|
||||
key = (capability_name, adapter_name)
|
||||
bucket = provider_buckets.setdefault(
|
||||
key,
|
||||
{
|
||||
"capability": capability_name,
|
||||
"adapter": adapter_name,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"estimated_cost_usd": 0.0,
|
||||
},
|
||||
)
|
||||
call_count = int(row["call_count"])
|
||||
bucket["call_count"] += call_count
|
||||
bucket["success_count"] += int(row["success_count"])
|
||||
bucket["failure_count"] += int(row["failure_count"])
|
||||
bucket["estimated_cost_usd"] += float(row["estimated_cost_usd"])
|
||||
|
||||
if row["avg_latency_ms"] is not None and call_count:
|
||||
latency_totals[key] = latency_totals.get(key, 0.0) + (
|
||||
float(row["avg_latency_ms"]) * call_count
|
||||
)
|
||||
latency_counts[key] = latency_counts.get(key, 0) + call_count
|
||||
|
||||
for item in payload["failure_reasons"]:
|
||||
reason = str(item["reason"])
|
||||
failure_reasons[reason] = failure_reasons.get(reason, 0) + int(item["count"])
|
||||
|
||||
by_provider = []
|
||||
total_latency = 0.0
|
||||
latency_count = 0
|
||||
for key, bucket in provider_buckets.items():
|
||||
bucket_latency_count = latency_counts.get(key, 0)
|
||||
bucket_latency_total = latency_totals.get(key, 0.0)
|
||||
if bucket_latency_count:
|
||||
total_latency += bucket_latency_total
|
||||
latency_count += bucket_latency_count
|
||||
by_provider.append(
|
||||
{
|
||||
**bucket,
|
||||
"avg_latency_ms": (
|
||||
round(bucket_latency_total / bucket_latency_count, 2)
|
||||
if bucket_latency_count
|
||||
else None
|
||||
),
|
||||
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
|
||||
}
|
||||
)
|
||||
|
||||
by_provider.sort(
|
||||
key=lambda item: (
|
||||
str(item["capability"]),
|
||||
str(item["adapter"]),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_calls": int(left["total_calls"]) + int(right["total_calls"]),
|
||||
"successful_calls": int(left["successful_calls"]) + int(right["successful_calls"]),
|
||||
"failed_calls": int(left["failed_calls"]) + int(right["failed_calls"]),
|
||||
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
|
||||
"estimated_cost_usd": round(
|
||||
float(left["estimated_cost_usd"]) + float(right["estimated_cost_usd"]),
|
||||
6,
|
||||
),
|
||||
"by_provider": by_provider,
|
||||
"failure_reasons": [
|
||||
{"reason": reason, "count": count}
|
||||
for reason, count in sorted(
|
||||
failure_reasons.items(),
|
||||
key=lambda item: (-item[1], item[0]),
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _voice_asr_provider_from_turn(turn: VoiceTurn) -> str:
|
||||
story_patch = turn.story_patch or {}
|
||||
return str(story_patch.get("transcription_provider") or "unknown")
|
||||
|
||||
|
||||
async def _aggregate_voice_asr_provider_analytics(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
days: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate ASR telemetry from voice co-creation sessions."""
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=days) if days is not None else None
|
||||
|
||||
turn_query = (
|
||||
select(
|
||||
VoiceTurn,
|
||||
VoiceSession.user_id,
|
||||
VoiceSession.final_story_id,
|
||||
VoiceSession.id,
|
||||
)
|
||||
.join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
|
||||
.where(
|
||||
VoiceTurn.user_audio_path.isnot(None),
|
||||
VoiceTurn.user_transcript.isnot(None),
|
||||
)
|
||||
)
|
||||
failure_query = (
|
||||
select(
|
||||
VoiceSessionEvent,
|
||||
VoiceSession.user_id,
|
||||
VoiceSession.final_story_id,
|
||||
VoiceSession.id,
|
||||
)
|
||||
.join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id)
|
||||
.where(VoiceSessionEvent.event_type == "turn_transcription_failed")
|
||||
)
|
||||
cost_query = select(
|
||||
CostRecord.user_id,
|
||||
CostRecord.provider_name,
|
||||
CostRecord.estimated_cost,
|
||||
).where(CostRecord.capability == "asr")
|
||||
|
||||
if cutoff is not None:
|
||||
turn_query = turn_query.where(VoiceTurn.created_at >= cutoff)
|
||||
failure_query = failure_query.where(VoiceSessionEvent.created_at >= cutoff)
|
||||
cost_query = cost_query.where(CostRecord.timestamp >= cutoff)
|
||||
|
||||
turn_rows = (await db.execute(turn_query)).all()
|
||||
failure_rows = (await db.execute(failure_query)).all()
|
||||
cost_rows = (await db.execute(cost_query)).all()
|
||||
|
||||
costs_by_provider: dict[str, float] = {}
|
||||
costs_by_user: dict[str, float] = {}
|
||||
for user_id, provider_name, estimated_cost in cost_rows:
|
||||
cost = float(estimated_cost or 0.0)
|
||||
provider = str(provider_name or "unknown")
|
||||
costs_by_provider[provider] = costs_by_provider.get(provider, 0.0) + cost
|
||||
costs_by_user[str(user_id)] = costs_by_user.get(str(user_id), 0.0) + cost
|
||||
|
||||
provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
|
||||
failure_reasons: dict[str, int] = {}
|
||||
by_user: dict[str, dict[str, Any]] = {}
|
||||
user_ids: set[str] = set()
|
||||
story_ids: set[int] = set()
|
||||
voice_session_ids: set[str] = set()
|
||||
successful_calls = 0
|
||||
failed_calls = 0
|
||||
|
||||
def provider_bucket(adapter: str) -> dict[str, Any]:
|
||||
return provider_buckets.setdefault(
|
||||
("asr", adapter),
|
||||
{
|
||||
"capability": "asr",
|
||||
"adapter": adapter,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"avg_latency_ms": None,
|
||||
"estimated_cost_usd": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
for turn, user_id, final_story_id, session_id in turn_rows:
|
||||
user_id = str(user_id)
|
||||
adapter = _voice_asr_provider_from_turn(turn)
|
||||
user_ids.add(user_id)
|
||||
voice_session_ids.add(str(session_id))
|
||||
if final_story_id is not None:
|
||||
story_ids.add(int(final_story_id))
|
||||
|
||||
bucket = provider_bucket(adapter)
|
||||
bucket["call_count"] += 1
|
||||
bucket["success_count"] += 1
|
||||
successful_calls += 1
|
||||
|
||||
user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id))
|
||||
user_bucket["call_count"] += 1
|
||||
user_bucket["success_count"] += 1
|
||||
if final_story_id is not None:
|
||||
user_bucket["story_ids"].add(int(final_story_id))
|
||||
|
||||
for provider_name, cost in costs_by_provider.items():
|
||||
key = ("asr", provider_name)
|
||||
if key in provider_buckets:
|
||||
provider_buckets[key]["estimated_cost_usd"] += cost
|
||||
|
||||
for user_id, cost in costs_by_user.items():
|
||||
if user_id in by_user:
|
||||
by_user[user_id]["estimated_cost_usd"] += cost
|
||||
|
||||
for event, user_id, final_story_id, session_id in failure_rows:
|
||||
metadata = event.event_metadata or {}
|
||||
adapter = str(
|
||||
metadata.get("adapter")
|
||||
or metadata.get("transcription_provider")
|
||||
or "unknown"
|
||||
)
|
||||
user_id = str(user_id)
|
||||
reason = str(metadata.get("error") or "unknown_error")
|
||||
user_ids.add(user_id)
|
||||
voice_session_ids.add(str(session_id))
|
||||
if final_story_id is not None:
|
||||
story_ids.add(int(final_story_id))
|
||||
|
||||
bucket = provider_bucket(adapter)
|
||||
bucket["call_count"] += 1
|
||||
bucket["failure_count"] += 1
|
||||
failed_calls += 1
|
||||
failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
|
||||
|
||||
user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id))
|
||||
user_bucket["call_count"] += 1
|
||||
user_bucket["failure_count"] += 1
|
||||
if final_story_id is not None:
|
||||
user_bucket["story_ids"].add(int(final_story_id))
|
||||
|
||||
by_provider = [
|
||||
{
|
||||
**bucket,
|
||||
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
|
||||
}
|
||||
for bucket in provider_buckets.values()
|
||||
]
|
||||
by_provider.sort(
|
||||
key=lambda item: (
|
||||
str(item["capability"]),
|
||||
str(item["adapter"]),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_calls": successful_calls + failed_calls,
|
||||
"successful_calls": successful_calls,
|
||||
"failed_calls": failed_calls,
|
||||
"avg_latency_ms": None,
|
||||
"estimated_cost_usd": round(
|
||||
sum(float(bucket["estimated_cost_usd"]) for bucket in provider_buckets.values()),
|
||||
6,
|
||||
),
|
||||
"by_provider": by_provider,
|
||||
"failure_reasons": [
|
||||
{"reason": reason, "count": count}
|
||||
for reason, count in sorted(
|
||||
failure_reasons.items(),
|
||||
key=lambda item: (-item[1], item[0]),
|
||||
)
|
||||
],
|
||||
"by_user": by_user,
|
||||
"user_ids": user_ids,
|
||||
"story_ids": story_ids,
|
||||
"voice_session_ids": voice_session_ids,
|
||||
"voice_turn_count": successful_calls,
|
||||
}
|
||||
|
||||
|
||||
async def get_admin_provider_analytics(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
days: int | None = None,
|
||||
capability: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider telemetry across every user in the current environment."""
|
||||
|
||||
rows = (await db.execute(_provider_events_query(days=days))).all()
|
||||
events = [event for event, _, _ in rows]
|
||||
filtered_rows = [
|
||||
(event, user_id, story_id)
|
||||
for event, user_id, story_id in rows
|
||||
if _event_matches_capability(event, capability)
|
||||
]
|
||||
|
||||
by_user: dict[str, dict[str, Any]] = {}
|
||||
filtered_job_ids = {event.job_id for event, _, _ in filtered_rows}
|
||||
filtered_story_ids = {
|
||||
story_id for _, _, story_id in filtered_rows if story_id is not None
|
||||
}
|
||||
filtered_user_ids = {user_id for _, user_id, _ in filtered_rows}
|
||||
|
||||
for event, user_id, story_id in filtered_rows:
|
||||
bucket = by_user.setdefault(
|
||||
user_id,
|
||||
_empty_admin_user_bucket(user_id),
|
||||
)
|
||||
bucket["call_count"] += 1
|
||||
bucket["job_ids"].add(event.job_id)
|
||||
if story_id is not None:
|
||||
bucket["story_ids"].add(story_id)
|
||||
|
||||
if event.event_type == "provider_call_succeeded":
|
||||
bucket["success_count"] += 1
|
||||
bucket["estimated_cost_usd"] += (
|
||||
_as_float((event.event_metadata or {}).get("estimated_cost_usd")) or 0.0
|
||||
)
|
||||
else:
|
||||
bucket["failure_count"] += 1
|
||||
|
||||
provider_analytics = _aggregate_provider_events(events, capability=capability)
|
||||
voice_session_count = 0
|
||||
voice_turn_count = 0
|
||||
if capability in {None, "asr"}:
|
||||
asr_analytics = await _aggregate_voice_asr_provider_analytics(db, days=days)
|
||||
provider_analytics = _merge_provider_analytics(
|
||||
provider_analytics,
|
||||
asr_analytics,
|
||||
)
|
||||
filtered_user_ids.update(asr_analytics["user_ids"])
|
||||
filtered_story_ids.update(asr_analytics["story_ids"])
|
||||
voice_session_count = len(asr_analytics["voice_session_ids"])
|
||||
voice_turn_count = int(asr_analytics["voice_turn_count"])
|
||||
|
||||
for user_id, source_bucket in asr_analytics["by_user"].items():
|
||||
target_bucket = by_user.setdefault(
|
||||
user_id,
|
||||
_empty_admin_user_bucket(user_id),
|
||||
)
|
||||
_merge_admin_user_bucket(target_bucket, source_bucket)
|
||||
|
||||
return {
|
||||
"scope": "current_environment",
|
||||
"window_days": days,
|
||||
"capability": capability,
|
||||
**provider_analytics,
|
||||
"user_count": len(filtered_user_ids),
|
||||
"job_count": len(filtered_job_ids),
|
||||
"story_count": len(filtered_story_ids),
|
||||
"voice_session_count": voice_session_count,
|
||||
"voice_turn_count": voice_turn_count,
|
||||
"by_user": _serialize_admin_user_buckets(by_user),
|
||||
}
|
||||
@@ -11,7 +11,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||||
from app.db.models import (
|
||||
GenerationJob,
|
||||
GenerationJobEvent,
|
||||
Story,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -712,87 +716,6 @@ async def get_user_provider_analytics(
|
||||
}
|
||||
|
||||
|
||||
async def get_admin_provider_analytics(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
days: int | None = None,
|
||||
capability: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Aggregate provider telemetry across every user in the current environment."""
|
||||
|
||||
rows = (await db.execute(_provider_events_query(days=days))).all()
|
||||
events = [event for event, _, _ in rows]
|
||||
filtered_rows = [
|
||||
(event, user_id, story_id)
|
||||
for event, user_id, story_id in rows
|
||||
if _event_matches_capability(event, capability)
|
||||
]
|
||||
|
||||
by_user: dict[str, dict[str, Any]] = {}
|
||||
filtered_job_ids = {event.job_id for event, _, _ in filtered_rows}
|
||||
filtered_story_ids = {
|
||||
story_id for _, _, story_id in filtered_rows if story_id is not None
|
||||
}
|
||||
filtered_user_ids = {user_id for _, user_id, _ in filtered_rows}
|
||||
|
||||
for event, user_id, story_id in filtered_rows:
|
||||
bucket = by_user.setdefault(
|
||||
user_id,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"call_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"estimated_cost_usd": 0.0,
|
||||
"job_ids": set(),
|
||||
"story_ids": set(),
|
||||
},
|
||||
)
|
||||
bucket["call_count"] += 1
|
||||
bucket["job_ids"].add(event.job_id)
|
||||
if story_id is not None:
|
||||
bucket["story_ids"].add(story_id)
|
||||
|
||||
if event.event_type == "provider_call_succeeded":
|
||||
bucket["success_count"] += 1
|
||||
bucket["estimated_cost_usd"] += (
|
||||
_as_float((event.event_metadata or {}).get("estimated_cost_usd")) or 0.0
|
||||
)
|
||||
else:
|
||||
bucket["failure_count"] += 1
|
||||
|
||||
serialized_users = [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"call_count": bucket["call_count"],
|
||||
"success_count": bucket["success_count"],
|
||||
"failure_count": bucket["failure_count"],
|
||||
"job_count": len(bucket["job_ids"]),
|
||||
"story_count": len(bucket["story_ids"]),
|
||||
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
|
||||
}
|
||||
for user_id, bucket in by_user.items()
|
||||
]
|
||||
serialized_users.sort(
|
||||
key=lambda item: (
|
||||
-int(item["call_count"]),
|
||||
-float(item["estimated_cost_usd"]),
|
||||
str(item["user_id"]),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"scope": "current_environment",
|
||||
"window_days": days,
|
||||
"capability": capability,
|
||||
**_aggregate_provider_events(events, capability=capability),
|
||||
"user_count": len(filtered_user_ids),
|
||||
"job_count": len(filtered_job_ids),
|
||||
"story_count": len(filtered_story_ids),
|
||||
"by_user": serialized_users,
|
||||
}
|
||||
|
||||
|
||||
async def get_user_generation_ops_summary(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
|
||||
@@ -117,6 +117,7 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
if adapter_name == "openai_asr":
|
||||
return AdapterConfig(
|
||||
api_key=settings.openai_api_key,
|
||||
api_base=getattr(settings, "openai_api_base", ""),
|
||||
model=settings.voice_transcription_model,
|
||||
timeout_ms=60000,
|
||||
)
|
||||
@@ -131,6 +132,7 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
if adapter_name == "openai":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "openai_api_key", ""),
|
||||
api_base=getattr(settings, "openai_api_base", ""),
|
||||
model=settings.openai_model,
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user