feat: improve voice studio alpha recovery flow

This commit is contained in:
2026-04-19 23:25:41 +08:00
parent 46d6201529
commit 4ecf0c09c0
9 changed files with 657 additions and 14 deletions

View File

@@ -12,6 +12,7 @@ from fastapi import (
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.deps import require_user
from app.core.rate_limiter import check_rate_limit
from app.db.database import get_db
@@ -34,11 +35,14 @@ from app.services.voice_session_service import (
create_voice_turn_from_text_service,
create_voice_turn_from_upload_service,
finalize_voice_session_service,
get_latest_active_voice_session_service,
get_voice_session_detail_service,
get_voice_turn_audio_service,
get_voice_turn_service,
get_voice_turn_user_audio_service,
list_voice_sessions_service,
retry_voice_turn_audio_service,
retry_voice_turn_service,
)
router = APIRouter()
@@ -68,8 +72,13 @@ async def create_voice_session(
@router.get("/voice-sessions", response_model=list[VoiceSessionSummaryResponse])
async def list_voice_sessions(
limit: int = Query(default=8, ge=1, le=20),
limit: int = Query(
default=settings.voice_session_default_list_limit,
ge=1,
le=settings.voice_session_max_list_limit,
),
active_only: bool = Query(default=False),
active_first: bool = Query(default=True),
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
@@ -79,9 +88,19 @@ async def list_voice_sessions(
db,
limit=limit,
active_only=active_only,
active_first=active_first,
)
@router.get("/voice-sessions/active", response_model=VoiceSessionSummaryResponse | None)
async def get_latest_active_voice_session(
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Get the latest active voice session for quick resume behavior."""
return await get_latest_active_voice_session_service(user.id, db)
@router.get("/voice-sessions/{session_id}", response_model=VoiceSessionDetailResponse)
async def get_voice_session(
session_id: str,
@@ -158,6 +177,21 @@ async def get_voice_turn(
return await get_voice_turn_service(session_id, turn_id, user.id, db)
@router.post(
"/voice-sessions/{session_id}/turns/{turn_id}/retry",
response_model=VoiceTurnAcceptedResponse,
status_code=status.HTTP_202_ACCEPTED,
)
async def retry_voice_turn(
session_id: str,
turn_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Retry one failed voice turn using its saved transcript."""
return await retry_voice_turn_service(session_id, turn_id, user.id, db)
@router.get("/voice-sessions/{session_id}/turns/{turn_id}/audio")
async def get_voice_turn_audio(
session_id: str,
@@ -170,6 +204,20 @@ async def get_voice_turn_audio(
return Response(content=audio_bytes, media_type="audio/mpeg")
@router.post(
"/voice-sessions/{session_id}/turns/{turn_id}/retry-audio",
response_model=VoiceTurnSummaryResponse,
)
async def retry_voice_turn_audio(
session_id: str,
turn_id: str,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Retry assistant audio synthesis when one turn only has text output."""
return await retry_voice_turn_audio_service(session_id, turn_id, user.id, db)
@router.get("/voice-sessions/{session_id}/turns/{turn_id}/user-audio")
async def get_voice_turn_user_audio(
session_id: str,

View File

@@ -82,6 +82,18 @@ class Settings(BaseSettings):
"zh",
description="Preferred language hint for voice transcription",
)
voice_turn_max_upload_bytes: int = Field(
5 * 1024 * 1024,
description="Maximum accepted upload size in bytes for one voice turn audio file",
)
voice_session_default_list_limit: int = Field(
8,
description="Default number of recent voice sessions returned to the client",
)
voice_session_max_list_limit: int = Field(
20,
description="Maximum number of recent voice sessions returned to the client",
)
story_audio_cache_ttl_days: int = Field(
30,
description="TTL in days before cached story audio is pruned",

View File

@@ -101,6 +101,7 @@ class VoiceSessionSummaryResponse(BaseModel):
latest_detected_intent: str | None = None
latest_assistant_audio_ready: bool = False
last_turn_status: str | None = None
transcription_mode_hint: str | None = None
can_continue: bool = False
can_finalize: bool = False
last_error: str | None = None

View File

@@ -6,9 +6,10 @@ from datetime import datetime, timezone
from typing import Any
from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy import case, desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
from app.schemas.voice_session_schemas import (
@@ -134,6 +135,7 @@ def _session_to_summary(
session_audio_exists(latest_turn.assistant_audio_path) if latest_turn else False
),
last_turn_status=latest_turn.status if latest_turn else None,
transcription_mode_hint=settings.voice_transcription_mode,
can_continue=_session_can_continue(session),
can_finalize=_session_can_finalize(session),
last_error=session.last_error,
@@ -602,17 +604,29 @@ async def list_voice_sessions_service(
user_id: str,
db: AsyncSession,
*,
limit: int = 8,
limit: int | None = None,
active_only: bool = False,
active_first: bool = False,
) -> list[VoiceSessionSummaryResponse]:
query = (
select(VoiceSession)
.where(VoiceSession.user_id == user_id)
.order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at))
.limit(limit)
)
resolved_limit = limit or settings.voice_session_default_list_limit
resolved_limit = max(1, min(resolved_limit, settings.voice_session_max_list_limit))
query = select(VoiceSession).where(VoiceSession.user_id == user_id)
if active_only:
query = query.where(VoiceSession.status.in_(CONTINUABLE_SESSION_STATUSES))
if active_first:
query = query.order_by(
desc(
case(
(VoiceSession.status.in_(CONTINUABLE_SESSION_STATUSES), 1),
else_=0,
)
),
desc(VoiceSession.updated_at),
desc(VoiceSession.created_at),
)
else:
query = query.order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at))
query = query.limit(resolved_limit)
sessions = (await db.execute(query)).scalars().all()
summaries: list[VoiceSessionSummaryResponse] = []
@@ -628,6 +642,30 @@ async def list_voice_sessions_service(
return summaries
async def get_latest_active_voice_session_service(
user_id: str,
db: AsyncSession,
) -> VoiceSessionSummaryResponse | None:
query = (
select(VoiceSession)
.where(
VoiceSession.user_id == user_id,
VoiceSession.status.in_(CONTINUABLE_SESSION_STATUSES),
)
.order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at))
.limit(1)
)
session = (await db.execute(query)).scalar_one_or_none()
if session is None:
return None
latest_turn = await _get_latest_turn(db, session_id=session.id)
return _session_to_summary(
session,
latest_turn=latest_turn,
total_turns=session.current_turn_index,
)
async def create_voice_session_service(
request: VoiceSessionCreateRequest,
user_id: str,
@@ -766,6 +804,13 @@ async def create_voice_turn_from_upload_service(
status_code=409,
detail="Voice session is not ready for another turn.",
)
if not audio_bytes:
raise HTTPException(status_code=400, detail="上传音频为空,请重新录音后再试。")
if len(audio_bytes) > settings.voice_turn_max_upload_bytes:
raise HTTPException(
status_code=413,
detail="上传音频过大,请缩短单轮录音时长后再试。",
)
next_turn_index = session.current_turn_index + 1
user_audio_path = write_uploaded_user_audio(
session_id=session.id,
@@ -805,6 +850,91 @@ async def create_voice_turn_from_upload_service(
)
async def retry_voice_turn_service(
session_id: str,
turn_id: str,
user_id: str,
db: AsyncSession,
) -> VoiceTurnAcceptedResponse:
turn = await _get_owned_turn(
db,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
)
if turn.status != "failed":
raise HTTPException(status_code=409, detail="Only failed turns can be retried.")
if not turn.user_transcript:
raise HTTPException(status_code=409, detail="This turn has no transcript to retry.")
return await create_voice_turn_from_text_service(
session_id,
VoiceTurnCreateFallbackRequest(
transcript_text=turn.user_transcript,
duration_ms=turn.user_audio_duration_ms,
),
user_id,
db,
)
async def retry_voice_turn_audio_service(
session_id: str,
turn_id: str,
user_id: str,
db: AsyncSession,
) -> VoiceTurnSummaryResponse:
turn = await _get_owned_turn(
db,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
)
if not turn.assistant_text:
raise HTTPException(status_code=409, detail="This turn has no assistant text to speak.")
if session_audio_exists(turn.assistant_audio_path):
raise HTTPException(status_code=409, detail="Assistant audio already exists for this turn.")
try:
audio_bytes = await text_to_speech(
turn.assistant_text,
db=db,
user_id=user_id,
)
saved_path = write_session_audio(
build_turn_assistant_audio_path(turn.session_id, turn.turn_index),
audio_bytes,
)
turn.assistant_audio_path = saved_path
turn.assistant_audio_duration_ms = None
if turn.status == "narrative_ready":
turn.status = "audio_ready"
await db.commit()
await db.refresh(turn)
await _record_session_event(
db,
session_id=turn.session_id,
turn_id=turn.id,
event_type="assistant_audio_retry_succeeded",
status="succeeded",
message="Assistant audio regenerated for one voice turn.",
metadata={"audio_path": saved_path},
)
except Exception as exc:
await _record_session_event(
db,
session_id=turn.session_id,
turn_id=turn.id,
event_type="assistant_audio_retry_failed",
status="failed",
message="Assistant audio retry failed.",
metadata={"error": str(exc)},
)
raise HTTPException(status_code=503, detail="语音补发失败,请稍后再试。") from exc
return _turn_to_summary(turn)
async def get_voice_turn_service(
session_id: str,
turn_id: str,

View File

@@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, patch
from httpx import ASGITransport, AsyncClient
from app.core.config import settings
from app.db.database import get_db
from app.main import app
from app.services.adapters.text.models import StoryOutput
@@ -343,3 +344,221 @@ async def test_voice_session_list_orders_recent_sessions_first(
}
finally:
app.dependency_overrides.clear()
async def test_voice_session_active_endpoint_returns_latest_active_session(
db_session,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with (
patch(
"app.services.voice_session_service.generate_story_content",
new_callable=AsyncMock,
) as mock_generate,
patch(
"app.services.voice_session_service.text_to_speech",
new_callable=AsyncMock,
) as mock_tts,
):
mock_generate.return_value = StoryOutput(
mode="generated",
title="活动会话",
story_text="一段活动中的故事。",
cover_prompt_suggestion="活动会话封面",
)
mock_tts.return_value = b"active-audio"
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post("/api/voice-sessions", json={})
old_session_id = response.json()["id"]
await client.post(
f"/api/voice-sessions/{old_session_id}/abandon",
json={"reason": "旧会话结束"},
)
response = await client.post("/api/voice-sessions", json={})
active_session_id = response.json()["id"]
await client.post(
f"/api/voice-sessions/{active_session_id}/turns/fallback",
json={"transcript_text": "请继续一个新故事"},
)
response = await client.get("/api/voice-sessions/active")
assert response.status_code == 200
data = response.json()
assert data["id"] == active_session_id
assert data["can_continue"] is True
assert data["status"] == "waiting_user"
finally:
app.dependency_overrides.clear()
async def test_voice_session_can_retry_failed_turn_from_saved_transcript(
db_session,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with (
patch(
"app.services.voice_session_service.generate_story_content",
new_callable=AsyncMock,
) as mock_generate,
patch(
"app.services.voice_session_service.text_to_speech",
new_callable=AsyncMock,
) as mock_tts,
):
mock_generate.side_effect = [
RuntimeError("provider down"),
StoryOutput(
mode="generated",
title="重试成功",
story_text="重试后的故事终于顺利继续了。",
cover_prompt_suggestion="重试封面",
),
]
mock_tts.return_value = b"retry-turn-audio"
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post("/api/voice-sessions", json={})
session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{session_id}/turns/fallback",
json={"transcript_text": "先讲一个会失败的回合"},
)
assert response.status_code == 202
failed_turn_id = response.json()["turn_id"]
response = await client.get(
f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}"
)
assert response.status_code == 200
assert response.json()["status"] == "failed"
response = await client.post(
f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}/retry"
)
assert response.status_code == 202
retried_turn_id = response.json()["turn_id"]
assert retried_turn_id != failed_turn_id
response = await client.get(
f"/api/voice-sessions/{session_id}/turns/{retried_turn_id}"
)
assert response.status_code == 200
retried_turn = response.json()
assert retried_turn["status"] == "audio_ready"
assert retried_turn["assistant_text"] == "重试后的故事终于顺利继续了。"
finally:
app.dependency_overrides.clear()
async def test_voice_session_can_retry_missing_assistant_audio(
db_session,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with (
patch(
"app.services.voice_session_service.generate_story_content",
new_callable=AsyncMock,
) as mock_generate,
patch(
"app.services.voice_session_service.text_to_speech",
new_callable=AsyncMock,
) as mock_tts,
):
mock_generate.return_value = StoryOutput(
mode="generated",
title="补发语音",
story_text="这一轮先有文本,稍后再补发语音。",
cover_prompt_suggestion="补发语音封面",
)
mock_tts.side_effect = [RuntimeError("tts down"), b"recovered-audio"]
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post("/api/voice-sessions", json={})
session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{session_id}/turns/fallback",
json={"transcript_text": "请先给我一段只有文本的结果"},
)
assert response.status_code == 202
turn_id = response.json()["turn_id"]
response = await client.get(
f"/api/voice-sessions/{session_id}/turns/{turn_id}"
)
assert response.status_code == 200
turn = response.json()
assert turn["status"] == "narrative_ready"
assert turn["assistant_audio_ready"] is False
response = await client.post(
f"/api/voice-sessions/{session_id}/turns/{turn_id}/retry-audio"
)
assert response.status_code == 200
retried = response.json()
assert retried["status"] == "audio_ready"
assert retried["assistant_audio_ready"] is True
finally:
app.dependency_overrides.clear()
async def test_voice_session_uploaded_audio_respects_size_limit(
db_session,
auth_token,
monkeypatch,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
monkeypatch.setattr(settings, "voice_turn_max_upload_bytes", 4)
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post("/api/voice-sessions", json={})
assert response.status_code == 201
session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{session_id}/turns",
files={
"audio_file": ("turn.webm", b"12345", "audio/webm"),
},
data={"transcript_hint": "太长了"},
)
assert response.status_code == 413
finally:
app.dependency_overrides.clear()