Implement unified story generation flow

This commit is contained in:
2026-06-18 14:48:27 +08:00
parent 0ccfd00a23
commit 7ebdfb2582
27 changed files with 1323 additions and 215 deletions

View File

@@ -2,25 +2,24 @@
# DREAMWEAVER 环境变量配置模板
# ==============================================
# 使用说明:
# 1. 复制此文件为 .env
# 1. 在仓库根目录执行cp backend/.env.example backend/.env
# 2. 填入您的 API Keys
# 3. 配合 docker-compose.yml 启动
# 3. 后端、Celery、Docker demo 都读取 backend/.env
# 4. 仓库根目录 .env 仅供 Docker Compose 自身读取构建参数,不放后端密钥
# ==============================================
# ----------------------------------------------
# 1. 基础设施 (Infrastructure) [必填]
# ----------------------------------------------
# ⚠️ Docker 启动时无需修改这部分,直接使用默认值即可
# ⚠️ 仅当您想连接外部数据库时才修改这里
# ⚠️ Docker 演示通常无需修改这部分,直接使用默认值即可
# ⚠️ 本机直跑后端时,把 DATABASE_URL/CELERY_* 改成文件末尾的 localhost 版本
POSTGRES_USER=dreamweaver
POSTGRES_PASSWORD=dreamweaver_password
POSTGRES_DB=dreamweaver_db
POSTGRES_PORT=5432
REDIS_PORT=6379
DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db:5432/${POSTGRES_DB}
DATABASE_URL=postgresql+asyncpg://dreamweaver:dreamweaver_password@db:5432/dreamweaver_db
CELERY_BROKER_URL=redis://redis:6379/0
CELERY_RESULT_BACKEND=redis://redis:6379/0
REDIS_URL=redis://redis:6379/0
# Web Security
SECRET_KEY=change-me-to-a-secure-random-string-in-production
@@ -44,6 +43,7 @@ TTS_PROVIDERS=["minimax", "elevenlabs", "edge_tts"]
# 绘本结构生成: 默认复用 Gemini Storybook adapter
STORYBOOK_PROVIDERS=["storybook_primary"]
# 语音识别: 本地演示默认 demo真实转写可设置为 ["openai_asr", "demo"]
# 真实 ASR smoke 必须让 openai_asr 排在 demo 前面,否则 demo hint 路径会先命中。
ASR_PROVIDERS=["demo"]
# [模型参数]
@@ -83,8 +83,10 @@ ELEVENLABS_API_KEY=
# OpenAI (如需使用)
OPENAI_API_KEY=
# 可选OpenAI 官方地址可留空;使用兼容网关时填类似 https://example.com/v1
OPENAI_API_BASE=
# OpenAI ASR
VOICE_TRANSCRIPTION_MODE=provider
VOICE_TRANSCRIPTION_MODEL=gpt-4o-mini-transcribe
VOICE_TRANSCRIPTION_LANGUAGE=zh
@@ -122,6 +124,8 @@ CORS_ORIGINS=["http://localhost:52080", "http://localhost:52888", "http://localh
# [本地开发覆盖 Local Dev Override]
# 如果您不使用 Docker而是在本机直接运行 `python -m uvicorn ...`
# 请取消注释以下行以连接 localhost 数据库:
# 请改用以下值连接 localhost 数据库/Redis
# DATABASE_URL=postgresql+asyncpg://dreamweaver:dreamweaver_password@localhost:52432/dreamweaver_db
# CELERY_BROKER_URL=redis://localhost:52379/0
# CELERY_RESULT_BACKEND=redis://localhost:52379/0
# REDIS_URL=redis://localhost:52379/0

View File

@@ -1,4 +1,5 @@
FROM python:3.11-slim
ARG PYTHON_BASE_IMAGE=python:3.11-slim
FROM ${PYTHON_BASE_IMAGE}
WORKDIR /app

View File

@@ -9,8 +9,8 @@ from app.core.admin_auth import admin_guard
from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.admin_provider_analytics import get_admin_provider_analytics
from app.services.cost_tracker import cost_tracker
from app.services.generation_jobs import get_admin_provider_analytics
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
@@ -97,6 +97,8 @@ class ProviderAnalyticsResponse(BaseModel):
user_count: int
job_count: int
story_count: int
voice_session_count: int = 0
voice_turn_count: int = 0
by_provider: list[ProviderAnalyticsBucket]
by_user: list[ProviderAnalyticsUserBucket]
failure_reasons: list[ProviderAnalyticsFailureReason]

View File

@@ -1,15 +1,20 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""应用全局配置"""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
from pathlib import Path
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
BACKEND_DIR = Path(__file__).resolve().parents[2]
BACKEND_ENV_FILE = BACKEND_DIR / ".env"
class Settings(BaseSettings):
"""应用全局配置"""
model_config = SettingsConfigDict(
env_file=BACKEND_ENV_FILE,
env_file_encoding="utf-8",
extra="ignore",
)
# 应用基础配置
app_name: str = "DreamWeaver"
@@ -34,9 +39,10 @@ class Settings(BaseSettings):
tts_api_key: str = ""
image_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
elevenlabs_api_key: str = ""
# Additional Provider API Keys
openai_api_key: str = ""
openai_api_base: str = ""
elevenlabs_api_key: str = ""
cqtai_api_key: str = ""
minimax_api_key: str = ""
minimax_group_id: str = ""

View File

@@ -9,6 +9,7 @@ from app.services.adapters.asr import openai as _asr_openai_adapter # noqa: F40
from app.services.adapters.base import AdapterConfig, BaseAdapter
# Image adapters
from app.services.adapters.image import antigravity as _image_antigravity_adapter # noqa: F401
from app.services.adapters.image import cqtai as _image_cqtai_adapter # noqa: F401
from app.services.adapters.registry import AdapterRegistry

View File

@@ -2,10 +2,11 @@
from __future__ import annotations
import re
from io import BytesIO
from fastapi import HTTPException
from openai import AsyncOpenAI
from openai import APIConnectionError, APIStatusError, APITimeoutError, AsyncOpenAI
from app.core.logging import get_logger
from app.services.adapters.asr.models import TranscriptionOutput
@@ -15,6 +16,14 @@ from app.services.adapters.registry import AdapterRegistry
logger = get_logger(__name__)
def _mask_openai_error(message: str) -> str:
"""Avoid leaking bearer tokens while keeping ASR smoke failures actionable."""
sanitized = message.replace("\n", " ").strip()
sanitized = re.sub(r"Bearer\s+[A-Za-z0-9._-]+", "Bearer ***", sanitized)
return re.sub(r"sk-[A-Za-z0-9_-]+", "sk-***", sanitized)
@AdapterRegistry.register("asr", "openai_asr")
class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
"""Transcribe uploaded voice turn audio with OpenAI audio transcription."""
@@ -37,7 +46,11 @@ class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
detail="OPENAI_API_KEY 未配置,无法使用 OpenAI 语音转写。",
)
client = AsyncOpenAI(api_key=self.config.api_key)
client = AsyncOpenAI(
api_key=self.config.api_key,
base_url=self.config.api_base or None,
timeout=self.config.timeout_ms / 1000,
)
audio_file = BytesIO(audio_bytes)
audio_file.name = file_name or "voice-turn.webm"
@@ -51,11 +64,29 @@ class OpenAIASRAdapter(BaseAdapter[TranscriptionOutput]):
language=language,
prompt=prompt,
)
except APIStatusError as exc:
detail = _mask_openai_error(getattr(exc, "message", str(exc)))
logger.warning(
"openai_asr_failed",
status_code=exc.status_code,
error=detail,
)
raise HTTPException(
status_code=503,
detail=f"OpenAI ASR 调用失败HTTP {exc.status_code}{detail}",
) from exc
except (APITimeoutError, APIConnectionError) as exc:
detail = _mask_openai_error(str(exc))
logger.warning("openai_asr_failed", error=detail)
raise HTTPException(
status_code=503,
detail=f"OpenAI ASR 网络连接失败:{detail}",
) from exc
except Exception as exc:
logger.warning("openai_asr_failed", error=str(exc))
raise HTTPException(
status_code=503,
detail="语音转写服务暂时不可用,请稍后重试。",
detail=f"OpenAI ASR 调用异常:{_mask_openai_error(str(exc))}",
) from exc
transcript_text = (getattr(response, "text", "") or "").strip()

View File

@@ -126,6 +126,11 @@ class MiniMaxTTSAdapter(BaseAdapter[bytes]):
except Exception:
return False
@property
def estimated_cost(self) -> float:
"""预估每次短文本语音合成成本 (USD)。"""
return 0.01
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=1, max=10),

View File

@@ -0,0 +1,408 @@
"""Admin-facing provider analytics across generation and voice telemetry."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.admin_models import CostRecord
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
from app.services.generation_jobs import (
_aggregate_provider_events,
_as_float,
_event_matches_capability,
_provider_events_query,
)
def _empty_admin_user_bucket(user_id: str) -> dict[str, Any]:
return {
"user_id": user_id,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"estimated_cost_usd": 0.0,
"job_ids": set(),
"story_ids": set(),
}
def _merge_admin_user_bucket(
target: dict[str, Any],
source: dict[str, Any],
) -> None:
target["call_count"] += int(source["call_count"])
target["success_count"] += int(source["success_count"])
target["failure_count"] += int(source["failure_count"])
target["estimated_cost_usd"] += float(source["estimated_cost_usd"])
target["job_ids"].update(source["job_ids"])
target["story_ids"].update(source["story_ids"])
def _serialize_admin_user_buckets(
by_user: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
serialized_users = [
{
"user_id": user_id,
"call_count": bucket["call_count"],
"success_count": bucket["success_count"],
"failure_count": bucket["failure_count"],
"job_count": len(bucket["job_ids"]),
"story_count": len(bucket["story_ids"]),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
for user_id, bucket in by_user.items()
]
serialized_users.sort(
key=lambda item: (
-int(item["call_count"]),
-float(item["estimated_cost_usd"]),
str(item["user_id"]),
)
)
return serialized_users
def _merge_provider_analytics(
left: dict[str, Any],
right: dict[str, Any],
) -> dict[str, Any]:
provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
latency_totals: dict[tuple[str, str], float] = {}
latency_counts: dict[tuple[str, str], int] = {}
failure_reasons: dict[str, int] = {}
for payload in (left, right):
for row in payload["by_provider"]:
capability_name = str(row["capability"])
adapter_name = str(row["adapter"])
key = (capability_name, adapter_name)
bucket = provider_buckets.setdefault(
key,
{
"capability": capability_name,
"adapter": adapter_name,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"estimated_cost_usd": 0.0,
},
)
call_count = int(row["call_count"])
bucket["call_count"] += call_count
bucket["success_count"] += int(row["success_count"])
bucket["failure_count"] += int(row["failure_count"])
bucket["estimated_cost_usd"] += float(row["estimated_cost_usd"])
if row["avg_latency_ms"] is not None and call_count:
latency_totals[key] = latency_totals.get(key, 0.0) + (
float(row["avg_latency_ms"]) * call_count
)
latency_counts[key] = latency_counts.get(key, 0) + call_count
for item in payload["failure_reasons"]:
reason = str(item["reason"])
failure_reasons[reason] = failure_reasons.get(reason, 0) + int(item["count"])
by_provider = []
total_latency = 0.0
latency_count = 0
for key, bucket in provider_buckets.items():
bucket_latency_count = latency_counts.get(key, 0)
bucket_latency_total = latency_totals.get(key, 0.0)
if bucket_latency_count:
total_latency += bucket_latency_total
latency_count += bucket_latency_count
by_provider.append(
{
**bucket,
"avg_latency_ms": (
round(bucket_latency_total / bucket_latency_count, 2)
if bucket_latency_count
else None
),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
)
by_provider.sort(
key=lambda item: (
str(item["capability"]),
str(item["adapter"]),
)
)
return {
"total_calls": int(left["total_calls"]) + int(right["total_calls"]),
"successful_calls": int(left["successful_calls"]) + int(right["successful_calls"]),
"failed_calls": int(left["failed_calls"]) + int(right["failed_calls"]),
"avg_latency_ms": round(total_latency / latency_count, 2) if latency_count else None,
"estimated_cost_usd": round(
float(left["estimated_cost_usd"]) + float(right["estimated_cost_usd"]),
6,
),
"by_provider": by_provider,
"failure_reasons": [
{"reason": reason, "count": count}
for reason, count in sorted(
failure_reasons.items(),
key=lambda item: (-item[1], item[0]),
)
],
}
def _voice_asr_provider_from_turn(turn: VoiceTurn) -> str:
story_patch = turn.story_patch or {}
return str(story_patch.get("transcription_provider") or "unknown")
async def _aggregate_voice_asr_provider_analytics(
db: AsyncSession,
*,
days: int | None = None,
) -> dict[str, Any]:
"""Aggregate ASR telemetry from voice co-creation sessions."""
cutoff = datetime.now(timezone.utc) - timedelta(days=days) if days is not None else None
turn_query = (
select(
VoiceTurn,
VoiceSession.user_id,
VoiceSession.final_story_id,
VoiceSession.id,
)
.join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
.where(
VoiceTurn.user_audio_path.isnot(None),
VoiceTurn.user_transcript.isnot(None),
)
)
failure_query = (
select(
VoiceSessionEvent,
VoiceSession.user_id,
VoiceSession.final_story_id,
VoiceSession.id,
)
.join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id)
.where(VoiceSessionEvent.event_type == "turn_transcription_failed")
)
cost_query = select(
CostRecord.user_id,
CostRecord.provider_name,
CostRecord.estimated_cost,
).where(CostRecord.capability == "asr")
if cutoff is not None:
turn_query = turn_query.where(VoiceTurn.created_at >= cutoff)
failure_query = failure_query.where(VoiceSessionEvent.created_at >= cutoff)
cost_query = cost_query.where(CostRecord.timestamp >= cutoff)
turn_rows = (await db.execute(turn_query)).all()
failure_rows = (await db.execute(failure_query)).all()
cost_rows = (await db.execute(cost_query)).all()
costs_by_provider: dict[str, float] = {}
costs_by_user: dict[str, float] = {}
for user_id, provider_name, estimated_cost in cost_rows:
cost = float(estimated_cost or 0.0)
provider = str(provider_name or "unknown")
costs_by_provider[provider] = costs_by_provider.get(provider, 0.0) + cost
costs_by_user[str(user_id)] = costs_by_user.get(str(user_id), 0.0) + cost
provider_buckets: dict[tuple[str, str], dict[str, Any]] = {}
failure_reasons: dict[str, int] = {}
by_user: dict[str, dict[str, Any]] = {}
user_ids: set[str] = set()
story_ids: set[int] = set()
voice_session_ids: set[str] = set()
successful_calls = 0
failed_calls = 0
def provider_bucket(adapter: str) -> dict[str, Any]:
return provider_buckets.setdefault(
("asr", adapter),
{
"capability": "asr",
"adapter": adapter,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"avg_latency_ms": None,
"estimated_cost_usd": 0.0,
},
)
for turn, user_id, final_story_id, session_id in turn_rows:
user_id = str(user_id)
adapter = _voice_asr_provider_from_turn(turn)
user_ids.add(user_id)
voice_session_ids.add(str(session_id))
if final_story_id is not None:
story_ids.add(int(final_story_id))
bucket = provider_bucket(adapter)
bucket["call_count"] += 1
bucket["success_count"] += 1
successful_calls += 1
user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id))
user_bucket["call_count"] += 1
user_bucket["success_count"] += 1
if final_story_id is not None:
user_bucket["story_ids"].add(int(final_story_id))
for provider_name, cost in costs_by_provider.items():
key = ("asr", provider_name)
if key in provider_buckets:
provider_buckets[key]["estimated_cost_usd"] += cost
for user_id, cost in costs_by_user.items():
if user_id in by_user:
by_user[user_id]["estimated_cost_usd"] += cost
for event, user_id, final_story_id, session_id in failure_rows:
metadata = event.event_metadata or {}
adapter = str(
metadata.get("adapter")
or metadata.get("transcription_provider")
or "unknown"
)
user_id = str(user_id)
reason = str(metadata.get("error") or "unknown_error")
user_ids.add(user_id)
voice_session_ids.add(str(session_id))
if final_story_id is not None:
story_ids.add(int(final_story_id))
bucket = provider_bucket(adapter)
bucket["call_count"] += 1
bucket["failure_count"] += 1
failed_calls += 1
failure_reasons[reason] = failure_reasons.get(reason, 0) + 1
user_bucket = by_user.setdefault(user_id, _empty_admin_user_bucket(user_id))
user_bucket["call_count"] += 1
user_bucket["failure_count"] += 1
if final_story_id is not None:
user_bucket["story_ids"].add(int(final_story_id))
by_provider = [
{
**bucket,
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
for bucket in provider_buckets.values()
]
by_provider.sort(
key=lambda item: (
str(item["capability"]),
str(item["adapter"]),
)
)
return {
"total_calls": successful_calls + failed_calls,
"successful_calls": successful_calls,
"failed_calls": failed_calls,
"avg_latency_ms": None,
"estimated_cost_usd": round(
sum(float(bucket["estimated_cost_usd"]) for bucket in provider_buckets.values()),
6,
),
"by_provider": by_provider,
"failure_reasons": [
{"reason": reason, "count": count}
for reason, count in sorted(
failure_reasons.items(),
key=lambda item: (-item[1], item[0]),
)
],
"by_user": by_user,
"user_ids": user_ids,
"story_ids": story_ids,
"voice_session_ids": voice_session_ids,
"voice_turn_count": successful_calls,
}
async def get_admin_provider_analytics(
db: AsyncSession,
*,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry across every user in the current environment."""
rows = (await db.execute(_provider_events_query(days=days))).all()
events = [event for event, _, _ in rows]
filtered_rows = [
(event, user_id, story_id)
for event, user_id, story_id in rows
if _event_matches_capability(event, capability)
]
by_user: dict[str, dict[str, Any]] = {}
filtered_job_ids = {event.job_id for event, _, _ in filtered_rows}
filtered_story_ids = {
story_id for _, _, story_id in filtered_rows if story_id is not None
}
filtered_user_ids = {user_id for _, user_id, _ in filtered_rows}
for event, user_id, story_id in filtered_rows:
bucket = by_user.setdefault(
user_id,
_empty_admin_user_bucket(user_id),
)
bucket["call_count"] += 1
bucket["job_ids"].add(event.job_id)
if story_id is not None:
bucket["story_ids"].add(story_id)
if event.event_type == "provider_call_succeeded":
bucket["success_count"] += 1
bucket["estimated_cost_usd"] += (
_as_float((event.event_metadata or {}).get("estimated_cost_usd")) or 0.0
)
else:
bucket["failure_count"] += 1
provider_analytics = _aggregate_provider_events(events, capability=capability)
voice_session_count = 0
voice_turn_count = 0
if capability in {None, "asr"}:
asr_analytics = await _aggregate_voice_asr_provider_analytics(db, days=days)
provider_analytics = _merge_provider_analytics(
provider_analytics,
asr_analytics,
)
filtered_user_ids.update(asr_analytics["user_ids"])
filtered_story_ids.update(asr_analytics["story_ids"])
voice_session_count = len(asr_analytics["voice_session_ids"])
voice_turn_count = int(asr_analytics["voice_turn_count"])
for user_id, source_bucket in asr_analytics["by_user"].items():
target_bucket = by_user.setdefault(
user_id,
_empty_admin_user_bucket(user_id),
)
_merge_admin_user_bucket(target_bucket, source_bucket)
return {
"scope": "current_environment",
"window_days": days,
"capability": capability,
**provider_analytics,
"user_count": len(filtered_user_ids),
"job_count": len(filtered_job_ids),
"story_count": len(filtered_story_ids),
"voice_session_count": voice_session_count,
"voice_turn_count": voice_turn_count,
"by_user": _serialize_admin_user_buckets(by_user),
}

View File

@@ -11,7 +11,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import GenerationJob, GenerationJobEvent, Story
from app.db.models import (
GenerationJob,
GenerationJobEvent,
Story,
)
logger = get_logger(__name__)
@@ -712,87 +716,6 @@ async def get_user_provider_analytics(
}
async def get_admin_provider_analytics(
db: AsyncSession,
*,
days: int | None = None,
capability: str | None = None,
) -> dict[str, Any]:
"""Aggregate provider telemetry across every user in the current environment."""
rows = (await db.execute(_provider_events_query(days=days))).all()
events = [event for event, _, _ in rows]
filtered_rows = [
(event, user_id, story_id)
for event, user_id, story_id in rows
if _event_matches_capability(event, capability)
]
by_user: dict[str, dict[str, Any]] = {}
filtered_job_ids = {event.job_id for event, _, _ in filtered_rows}
filtered_story_ids = {
story_id for _, _, story_id in filtered_rows if story_id is not None
}
filtered_user_ids = {user_id for _, user_id, _ in filtered_rows}
for event, user_id, story_id in filtered_rows:
bucket = by_user.setdefault(
user_id,
{
"user_id": user_id,
"call_count": 0,
"success_count": 0,
"failure_count": 0,
"estimated_cost_usd": 0.0,
"job_ids": set(),
"story_ids": set(),
},
)
bucket["call_count"] += 1
bucket["job_ids"].add(event.job_id)
if story_id is not None:
bucket["story_ids"].add(story_id)
if event.event_type == "provider_call_succeeded":
bucket["success_count"] += 1
bucket["estimated_cost_usd"] += (
_as_float((event.event_metadata or {}).get("estimated_cost_usd")) or 0.0
)
else:
bucket["failure_count"] += 1
serialized_users = [
{
"user_id": user_id,
"call_count": bucket["call_count"],
"success_count": bucket["success_count"],
"failure_count": bucket["failure_count"],
"job_count": len(bucket["job_ids"]),
"story_count": len(bucket["story_ids"]),
"estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
}
for user_id, bucket in by_user.items()
]
serialized_users.sort(
key=lambda item: (
-int(item["call_count"]),
-float(item["estimated_cost_usd"]),
str(item["user_id"]),
)
)
return {
"scope": "current_environment",
"window_days": days,
"capability": capability,
**_aggregate_provider_events(events, capability=capability),
"user_count": len(filtered_user_ids),
"job_count": len(filtered_job_ids),
"story_count": len(filtered_story_ids),
"by_user": serialized_users,
}
async def get_user_generation_ops_summary(
db: AsyncSession,
*,

View File

@@ -117,6 +117,7 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
if adapter_name == "openai_asr":
return AdapterConfig(
api_key=settings.openai_api_key,
api_base=getattr(settings, "openai_api_base", ""),
model=settings.voice_transcription_model,
timeout_ms=60000,
)
@@ -131,6 +132,7 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
api_base=getattr(settings, "openai_api_base", ""),
model=settings.openai_model,
timeout_ms=60000,
)

View File

@@ -1,12 +1,14 @@
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from app.api import admin_providers
from app.core.admin_auth import admin_guard
from app.db.admin_models import CostRecord
from app.db.database import get_db
from app.db.models import Story, User
from app.db.models import Story, User, VoiceSession, VoiceSessionEvent, VoiceTurn
from app.services.generation_jobs import create_generation_job, record_generation_event
@@ -286,3 +288,105 @@ async def test_admin_provider_analytics_support_days_and_capability_filters(
response = await client.get("/admin/providers/analytics?capability=unknown")
assert response.status_code == 422
async def test_admin_provider_analytics_includes_voice_asr_calls(
db_session,
test_user,
):
second_user = User(
id="google:asr-user",
name="ASR User",
avatar_url="https://example.com/asr.png",
provider="google",
)
db_session.add(second_user)
await db_session.commit()
successful_session = VoiceSession(user_id=test_user.id, status="active")
failed_session = VoiceSession(user_id=second_user.id, status="active")
db_session.add_all([successful_session, failed_session])
await db_session.commit()
await db_session.refresh(successful_session)
await db_session.refresh(failed_session)
db_session.add_all(
[
VoiceTurn(
session_id=successful_session.id,
turn_index=1,
status="completed",
user_audio_path="/tmp/voice-turn.webm",
user_audio_mime_type="audio/webm",
user_audio_duration_ms=1300,
user_transcript="我想听一个星星故事",
transcript_confidence=0.96,
detected_intent="continue_story",
intent_confidence=0.9,
story_patch={"transcription_provider": "demo"},
),
VoiceSessionEvent(
session_id=failed_session.id,
event_type="turn_transcription_failed",
status="failed",
message="Voice transcription failed.",
event_metadata={"error": "OPENAI_API_KEY 未配置"},
),
CostRecord(
user_id=test_user.id,
provider_name="demo",
capability="asr",
estimated_cost=Decimal("0.002"),
),
]
)
await db_session.commit()
admin_app = _build_admin_test_app(db_session)
transport = ASGITransport(app=admin_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/admin/providers/analytics?capability=asr")
assert response.status_code == 200
data = response.json()
assert data["capability"] == "asr"
assert data["total_calls"] == 2
assert data["successful_calls"] == 1
assert data["failed_calls"] == 1
assert data["user_count"] == 2
assert data["job_count"] == 0
assert data["story_count"] == 0
assert data["voice_session_count"] == 2
assert data["voice_turn_count"] == 1
assert data["estimated_cost_usd"] == 0.002
assert data["failure_reasons"] == [
{"reason": "OPENAI_API_KEY 未配置", "count": 1}
]
assert data["by_provider"] == [
{
"capability": "asr",
"adapter": "demo",
"call_count": 1,
"success_count": 1,
"failure_count": 0,
"avg_latency_ms": None,
"estimated_cost_usd": 0.002,
},
{
"capability": "asr",
"adapter": "unknown",
"call_count": 1,
"success_count": 0,
"failure_count": 1,
"avg_latency_ms": None,
"estimated_cost_usd": 0.0,
},
]
users = {row["user_id"]: row for row in data["by_user"]}
assert users[test_user.id]["call_count"] == 1
assert users[test_user.id]["success_count"] == 1
assert users[test_user.id]["estimated_cost_usd"] == 0.002
assert users[second_user.id]["call_count"] == 1
assert users[second_user.id]["failure_count"] == 1

View File

@@ -73,6 +73,7 @@ class TestDevSigninRedirect:
def test_dev_signin_uses_allowed_next_url(self, client: TestClient, monkeypatch):
"""允许的 next 参数应作为登录完成后的回跳地址。"""
monkeypatch.setattr(settings, "debug", True)
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(
@@ -86,6 +87,7 @@ class TestDevSigninRedirect:
def test_dev_signin_rejects_untrusted_next_url(self, client: TestClient, monkeypatch):
"""不可信的 next 参数应回退到默认前端地址,避免开放重定向。"""
monkeypatch.setattr(settings, "debug", True)
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(

View File

@@ -0,0 +1,53 @@
"""配置加载约定测试。"""
from pathlib import Path
from app.core.config import BACKEND_ENV_FILE, Settings
def test_default_env_file_is_backend_env():
"""默认 env 文件应固定为 backend/.env 的绝对路径。"""
configured_env_file = Path(Settings.model_config["env_file"])
assert configured_env_file == BACKEND_ENV_FILE
assert configured_env_file.is_absolute()
assert configured_env_file.parent.name == "backend"
assert configured_env_file.name == ".env"
def test_explicit_env_file_ignores_current_working_directory_dotenv(monkeypatch, tmp_path):
"""显式 env 文件不应被当前目录 .env 污染。"""
root_env = tmp_path / ".env"
root_env.write_text(
"\n".join(
[
"SECRET_KEY=root-env-should-not-be-used",
"DATABASE_URL=sqlite+aiosqlite:///root-env.db",
"DEBUG=false",
]
),
encoding="utf-8",
)
backend_env = tmp_path / "backend.env"
backend_env.write_text(
"\n".join(
[
"SECRET_KEY=backend-env-secret",
"DATABASE_URL=sqlite+aiosqlite:///backend-env.db",
"DEBUG=true",
]
),
encoding="utf-8",
)
monkeypatch.chdir(tmp_path)
monkeypatch.delenv("SECRET_KEY", raising=False)
monkeypatch.delenv("DATABASE_URL", raising=False)
settings = Settings(_env_file=backend_env)
assert settings.database_url == "sqlite+aiosqlite:///backend-env.db"
assert settings.secret_key == "backend-env-secret"
assert settings.debug is True

View File

@@ -299,6 +299,21 @@ class TestProviderPolicy:
assert result.transcript_text == "我想听一个小熊找星星的故事"
assert result.confidence == 1.0
assert result.provider == "demo"
def test_openai_asr_default_config_uses_openai_env(self):
from app.services.provider_router import _get_default_config
with patch("app.services.provider_router.settings") as mock_settings:
mock_settings.openai_api_key = "openai-key"
mock_settings.openai_api_base = "https://api.example.com/v1"
mock_settings.voice_transcription_model = "gpt-4o-mini-transcribe"
config = _get_default_config("openai_asr")
assert config is not None
assert config.api_key == "openai-key"
assert config.api_base == "https://api.example.com/v1"
assert config.model == "gpt-4o-mini-transcribe"
class TestProviderConfigFromDB: