Implement unified story generation flow
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.api import admin_providers
|
||||
from app.core.admin_auth import admin_guard
|
||||
from app.db.admin_models import CostRecord
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Story, User
|
||||
from app.db.models import Story, User, VoiceSession, VoiceSessionEvent, VoiceTurn
|
||||
from app.services.generation_jobs import create_generation_job, record_generation_event
|
||||
|
||||
|
||||
@@ -286,3 +288,105 @@ async def test_admin_provider_analytics_support_days_and_capability_filters(
|
||||
|
||||
response = await client.get("/admin/providers/analytics?capability=unknown")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
async def test_admin_provider_analytics_includes_voice_asr_calls(
|
||||
db_session,
|
||||
test_user,
|
||||
):
|
||||
second_user = User(
|
||||
id="google:asr-user",
|
||||
name="ASR User",
|
||||
avatar_url="https://example.com/asr.png",
|
||||
provider="google",
|
||||
)
|
||||
db_session.add(second_user)
|
||||
await db_session.commit()
|
||||
|
||||
successful_session = VoiceSession(user_id=test_user.id, status="active")
|
||||
failed_session = VoiceSession(user_id=second_user.id, status="active")
|
||||
db_session.add_all([successful_session, failed_session])
|
||||
await db_session.commit()
|
||||
await db_session.refresh(successful_session)
|
||||
await db_session.refresh(failed_session)
|
||||
|
||||
db_session.add_all(
|
||||
[
|
||||
VoiceTurn(
|
||||
session_id=successful_session.id,
|
||||
turn_index=1,
|
||||
status="completed",
|
||||
user_audio_path="/tmp/voice-turn.webm",
|
||||
user_audio_mime_type="audio/webm",
|
||||
user_audio_duration_ms=1300,
|
||||
user_transcript="我想听一个星星故事",
|
||||
transcript_confidence=0.96,
|
||||
detected_intent="continue_story",
|
||||
intent_confidence=0.9,
|
||||
story_patch={"transcription_provider": "demo"},
|
||||
),
|
||||
VoiceSessionEvent(
|
||||
session_id=failed_session.id,
|
||||
event_type="turn_transcription_failed",
|
||||
status="failed",
|
||||
message="Voice transcription failed.",
|
||||
event_metadata={"error": "OPENAI_API_KEY 未配置"},
|
||||
),
|
||||
CostRecord(
|
||||
user_id=test_user.id,
|
||||
provider_name="demo",
|
||||
capability="asr",
|
||||
estimated_cost=Decimal("0.002"),
|
||||
),
|
||||
]
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
admin_app = _build_admin_test_app(db_session)
|
||||
transport = ASGITransport(app=admin_app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.get("/admin/providers/analytics?capability=asr")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["capability"] == "asr"
|
||||
assert data["total_calls"] == 2
|
||||
assert data["successful_calls"] == 1
|
||||
assert data["failed_calls"] == 1
|
||||
assert data["user_count"] == 2
|
||||
assert data["job_count"] == 0
|
||||
assert data["story_count"] == 0
|
||||
assert data["voice_session_count"] == 2
|
||||
assert data["voice_turn_count"] == 1
|
||||
assert data["estimated_cost_usd"] == 0.002
|
||||
assert data["failure_reasons"] == [
|
||||
{"reason": "OPENAI_API_KEY 未配置", "count": 1}
|
||||
]
|
||||
assert data["by_provider"] == [
|
||||
{
|
||||
"capability": "asr",
|
||||
"adapter": "demo",
|
||||
"call_count": 1,
|
||||
"success_count": 1,
|
||||
"failure_count": 0,
|
||||
"avg_latency_ms": None,
|
||||
"estimated_cost_usd": 0.002,
|
||||
},
|
||||
{
|
||||
"capability": "asr",
|
||||
"adapter": "unknown",
|
||||
"call_count": 1,
|
||||
"success_count": 0,
|
||||
"failure_count": 1,
|
||||
"avg_latency_ms": None,
|
||||
"estimated_cost_usd": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
users = {row["user_id"]: row for row in data["by_user"]}
|
||||
assert users[test_user.id]["call_count"] == 1
|
||||
assert users[test_user.id]["success_count"] == 1
|
||||
assert users[test_user.id]["estimated_cost_usd"] == 0.002
|
||||
assert users[second_user.id]["call_count"] == 1
|
||||
assert users[second_user.id]["failure_count"] == 1
|
||||
|
||||
@@ -73,6 +73,7 @@ class TestDevSigninRedirect:
|
||||
|
||||
def test_dev_signin_uses_allowed_next_url(self, client: TestClient, monkeypatch):
|
||||
"""允许的 next 参数应作为登录完成后的回跳地址。"""
|
||||
monkeypatch.setattr(settings, "debug", True)
|
||||
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
|
||||
|
||||
response = client.get(
|
||||
@@ -86,6 +87,7 @@ class TestDevSigninRedirect:
|
||||
|
||||
def test_dev_signin_rejects_untrusted_next_url(self, client: TestClient, monkeypatch):
|
||||
"""不可信的 next 参数应回退到默认前端地址,避免开放重定向。"""
|
||||
monkeypatch.setattr(settings, "debug", True)
|
||||
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
|
||||
|
||||
response = client.get(
|
||||
|
||||
53
backend/tests/test_config.py
Normal file
53
backend/tests/test_config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""配置加载约定测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import BACKEND_ENV_FILE, Settings
|
||||
|
||||
|
||||
def test_default_env_file_is_backend_env():
|
||||
"""默认 env 文件应固定为 backend/.env 的绝对路径。"""
|
||||
|
||||
configured_env_file = Path(Settings.model_config["env_file"])
|
||||
|
||||
assert configured_env_file == BACKEND_ENV_FILE
|
||||
assert configured_env_file.is_absolute()
|
||||
assert configured_env_file.parent.name == "backend"
|
||||
assert configured_env_file.name == ".env"
|
||||
|
||||
|
||||
def test_explicit_env_file_ignores_current_working_directory_dotenv(monkeypatch, tmp_path):
|
||||
"""显式 env 文件不应被当前目录 .env 污染。"""
|
||||
|
||||
root_env = tmp_path / ".env"
|
||||
root_env.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"SECRET_KEY=root-env-should-not-be-used",
|
||||
"DATABASE_URL=sqlite+aiosqlite:///root-env.db",
|
||||
"DEBUG=false",
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
backend_env = tmp_path / "backend.env"
|
||||
backend_env.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"SECRET_KEY=backend-env-secret",
|
||||
"DATABASE_URL=sqlite+aiosqlite:///backend-env.db",
|
||||
"DEBUG=true",
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.delenv("SECRET_KEY", raising=False)
|
||||
monkeypatch.delenv("DATABASE_URL", raising=False)
|
||||
|
||||
settings = Settings(_env_file=backend_env)
|
||||
|
||||
assert settings.database_url == "sqlite+aiosqlite:///backend-env.db"
|
||||
assert settings.secret_key == "backend-env-secret"
|
||||
assert settings.debug is True
|
||||
@@ -299,6 +299,21 @@ class TestProviderPolicy:
|
||||
assert result.transcript_text == "我想听一个小熊找星星的故事"
|
||||
assert result.confidence == 1.0
|
||||
assert result.provider == "demo"
|
||||
|
||||
def test_openai_asr_default_config_uses_openai_env(self):
|
||||
from app.services.provider_router import _get_default_config
|
||||
|
||||
with patch("app.services.provider_router.settings") as mock_settings:
|
||||
mock_settings.openai_api_key = "openai-key"
|
||||
mock_settings.openai_api_base = "https://api.example.com/v1"
|
||||
mock_settings.voice_transcription_model = "gpt-4o-mini-transcribe"
|
||||
|
||||
config = _get_default_config("openai_asr")
|
||||
|
||||
assert config is not None
|
||||
assert config.api_key == "openai-key"
|
||||
assert config.api_base == "https://api.example.com/v1"
|
||||
assert config.model == "gpt-4o-mini-transcribe"
|
||||
|
||||
|
||||
class TestProviderConfigFromDB:
|
||||
|
||||
Reference in New Issue
Block a user