feat: refine voice studio attention workflow

This commit is contained in:
2026-04-21 14:19:51 +08:00
parent 8b50674d04
commit 9f74a93274
7 changed files with 1025 additions and 48 deletions

View File

@@ -1,5 +1,7 @@
"""Voice co-creation session APIs."""
from typing import Literal
from fastapi import (
APIRouter,
Depends,
@@ -82,6 +84,10 @@ async def list_voice_sessions(
le=settings.voice_session_max_list_limit,
),
active_only: bool = Query(default=False),
needs_attention: bool = Query(default=False),
attention_reason: (
Literal["pending_confirmation", "safety_intervention", "failed_turn"] | None
) = Query(default=None),
active_first: bool = Query(default=True),
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
@@ -92,6 +98,8 @@ async def list_voice_sessions(
db,
limit=limit,
active_only=active_only,
needs_attention=needs_attention,
attention_reason=attention_reason,
active_first=active_first,
)

View File

@@ -121,6 +121,7 @@ class VoiceSessionSummaryResponse(BaseModel):
latest_safety_message: str | None = None
latest_assistant_audio_ready: bool = False
last_turn_status: str | None = None
attention_reasons: list[str] = Field(default_factory=list)
transcription_mode_hint: str | None = None
can_continue: bool = False
can_finalize: bool = False
@@ -149,6 +150,10 @@ class VoiceSessionAnalyticsResponse(BaseModel):
window_days: int | None = None
total_sessions: int = 0
attention_sessions: int = 0
confirmation_attention_sessions: int = 0
safety_attention_sessions: int = 0
failed_attention_sessions: int = 0
active_sessions: int = 0
finalized_sessions: int = 0
abandoned_sessions: int = 0

View File

@@ -388,6 +388,12 @@ def _session_to_summary(
story_patch=latest_turn.story_patch or {},
)
latest_safety_state = _resolve_turn_safety_state(latest_turn.story_patch or {})
attention_reasons = _build_session_attention_reasons(
latest_requires_confirmation=latest_confirmation_state["requires_confirmation"],
latest_safety_flags=latest_safety_state["safety_flags"],
last_turn_status=latest_turn.status if latest_turn else None,
last_error=session.last_error,
)
return VoiceSessionSummaryResponse(
id=session.id,
@@ -413,12 +419,55 @@ def _session_to_summary(
session_audio_exists(latest_turn.assistant_audio_path) if latest_turn else False
),
last_turn_status=latest_turn.status if latest_turn else None,
attention_reasons=attention_reasons,
transcription_mode_hint=settings.voice_transcription_mode,
can_continue=_session_can_continue(session),
can_finalize=_can_finalize_with_latest_turn(session, latest_turn),
last_error=session.last_error,
created_at=session.created_at,
updated_at=session.updated_at,
)
def _build_session_attention_reasons(
*,
latest_requires_confirmation: bool,
latest_safety_flags: list[str] | None,
last_turn_status: str | None,
last_error: str | None,
) -> list[str]:
reasons: list[str] = []
if latest_requires_confirmation:
reasons.append("pending_confirmation")
if latest_safety_flags:
reasons.append("safety_intervention")
if last_turn_status == "failed" or last_error:
reasons.append("failed_turn")
return reasons
def _session_summary_needs_attention(summary: VoiceSessionSummaryResponse) -> bool:
return bool(summary.attention_reasons)
def _session_summary_matches_attention_reason(
summary: VoiceSessionSummaryResponse,
attention_reason: str | None,
) -> bool:
if attention_reason is None:
return True
return attention_reason in summary.attention_reasons
async def _build_session_summary(
db: AsyncSession,
session: VoiceSession,
) -> VoiceSessionSummaryResponse:
latest_turn = await _get_latest_turn(db, session_id=session.id)
return _session_to_summary(
session,
latest_turn=latest_turn,
total_turns=session.current_turn_index,
)
@@ -1082,6 +1131,8 @@ async def list_voice_sessions_service(
*,
limit: int | None = None,
active_only: bool = False,
needs_attention: bool = False,
attention_reason: str | None = None,
active_first: bool = False,
) -> list[VoiceSessionSummaryResponse]:
resolved_limit = limit or settings.voice_session_default_list_limit
@@ -1102,19 +1153,20 @@ async def list_voice_sessions_service(
)
else:
query = query.order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at))
query = query.limit(resolved_limit)
if not needs_attention and attention_reason is None:
query = query.limit(resolved_limit)
sessions = (await db.execute(query)).scalars().all()
summaries: list[VoiceSessionSummaryResponse] = []
for session in sessions:
latest_turn = await _get_latest_turn(db, session_id=session.id)
summaries.append(
_session_to_summary(
session,
latest_turn=latest_turn,
total_turns=session.current_turn_index,
)
)
summary = await _build_session_summary(db, session)
if needs_attention and not _session_summary_needs_attention(summary):
continue
if not _session_summary_matches_attention_reason(summary, attention_reason):
continue
summaries.append(summary)
if (needs_attention or attention_reason is not None) and len(summaries) >= resolved_limit:
break
return summaries
@@ -1134,12 +1186,7 @@ async def get_latest_active_voice_session_service(
session = (await db.execute(query)).scalar_one_or_none()
if session is None:
return None
latest_turn = await _get_latest_turn(db, session_id=session.id)
return _session_to_summary(
session,
latest_turn=latest_turn,
total_turns=session.current_turn_index,
)
return await _build_session_summary(db, session)
async def get_voice_session_analytics_service(
@@ -1172,8 +1219,25 @@ async def get_voice_session_analytics_service(
sessions = (await db.execute(session_query)).scalars().all()
turns = (await db.execute(turn_query)).scalars().all()
events = (await db.execute(event_query)).scalars().all()
session_summaries = [await _build_session_summary(db, session) for session in sessions]
total_sessions = len(sessions)
attention_sessions = sum(
1 for summary in session_summaries if _session_summary_needs_attention(summary)
)
confirmation_attention_sessions = sum(
1
for summary in session_summaries
if "pending_confirmation" in summary.attention_reasons
)
safety_attention_sessions = sum(
1
for summary in session_summaries
if "safety_intervention" in summary.attention_reasons
)
failed_attention_sessions = sum(
1 for summary in session_summaries if "failed_turn" in summary.attention_reasons
)
active_sessions = sum(
1 for session in sessions if session.status in CONTINUABLE_SESSION_STATUSES
)
@@ -1205,6 +1269,10 @@ async def get_voice_session_analytics_service(
return VoiceSessionAnalyticsResponse(
window_days=days,
total_sessions=total_sessions,
attention_sessions=attention_sessions,
confirmation_attention_sessions=confirmation_attention_sessions,
safety_attention_sessions=safety_attention_sessions,
failed_attention_sessions=failed_attention_sessions,
active_sessions=active_sessions,
finalized_sessions=finalized_sessions,
abandoned_sessions=abandoned_sessions,

View File

@@ -681,6 +681,149 @@ async def test_voice_session_analytics_summarize_failures_and_confirmations(
app.dependency_overrides.clear()
async def test_voice_session_attention_filter_and_analytics_count(
db_session,
auth_token,
):
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with (
patch(
"app.services.voice_session_service.generate_story_content",
new_callable=AsyncMock,
) as mock_generate,
patch(
"app.services.voice_session_service.text_to_speech",
new_callable=AsyncMock,
) as mock_tts,
patch(
"app.services.voice_session_service.transcribe_voice_audio",
new_callable=AsyncMock,
) as mock_transcribe,
):
mock_generate.side_effect = [
StoryOutput(
mode="generated",
title="正常故事",
story_text="第一段温暖故事。",
cover_prompt_suggestion="normal cover",
),
RuntimeError("provider down"),
]
mock_tts.side_effect = [
b"normal-audio",
b"confirmation-audio",
b"safety-audio",
]
mock_transcribe.return_value = VoiceTranscriptionResult(
transcript_text="我想听一个会发光的小恐龙故事",
confidence=0.41,
provider="openai",
)
transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url="http://test") as client:
client.cookies.set("access_token", auth_token)
response = await client.post("/api/voice-sessions", json={})
normal_session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{normal_session_id}/turns/fallback",
json={"transcript_text": "先讲一个温暖的普通故事"},
)
assert response.status_code == 202
response = await client.post("/api/voice-sessions", json={})
failed_session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{failed_session_id}/turns/fallback",
json={"transcript_text": "这轮会触发 provider 异常"},
)
assert response.status_code == 202
response = await client.post("/api/voice-sessions", json={})
confirmation_session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{confirmation_session_id}/turns",
files={
"audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
},
)
assert response.status_code == 202
response = await client.post("/api/voice-sessions", json={})
safety_session_id = response.json()["id"]
response = await client.post(
f"/api/voice-sessions/{safety_session_id}/turns/fallback",
json={"transcript_text": "我想听一个拿着炸弹互相打的故事"},
)
assert response.status_code == 202
response = await client.get(
"/api/voice-sessions?needs_attention=true&limit=8"
)
assert response.status_code == 200
attention_sessions = response.json()
attention_session_ids = {item["id"] for item in attention_sessions}
assert attention_session_ids == {
failed_session_id,
confirmation_session_id,
safety_session_id,
}
assert normal_session_id not in attention_session_ids
attention_reason_sets = {
item["id"]: set(item["attention_reasons"]) for item in attention_sessions
}
assert attention_reason_sets[confirmation_session_id] == {
"pending_confirmation"
}
assert attention_reason_sets[safety_session_id] == {
"safety_intervention"
}
assert attention_reason_sets[failed_session_id] == {"failed_turn"}
response = await client.get(
"/api/voice-sessions?needs_attention=true&attention_reason=pending_confirmation"
)
assert response.status_code == 200
confirmation_sessions = response.json()
assert [item["id"] for item in confirmation_sessions] == [
confirmation_session_id
]
response = await client.get(
"/api/voice-sessions?needs_attention=true&attention_reason=safety_intervention"
)
assert response.status_code == 200
safety_sessions = response.json()
assert [item["id"] for item in safety_sessions] == [safety_session_id]
response = await client.get(
"/api/voice-sessions?needs_attention=true&attention_reason=failed_turn"
)
assert response.status_code == 200
failed_sessions = response.json()
assert [item["id"] for item in failed_sessions] == [failed_session_id]
response = await client.get("/api/voice-sessions/analytics?days=30")
assert response.status_code == 200
analytics = response.json()
assert analytics["total_sessions"] == 4
assert analytics["attention_sessions"] == 3
assert analytics["confirmation_attention_sessions"] == 1
assert analytics["safety_attention_sessions"] == 1
assert analytics["failed_attention_sessions"] == 1
assert analytics["failed_turns"] >= 1
assert analytics["low_confidence_turns"] >= 1
assert analytics["safety_interventions"] >= 1
finally:
app.dependency_overrides.clear()
async def test_voice_session_list_orders_recent_sessions_first(
db_session,
auth_token,