feat: refine voice studio attention workflow
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Voice co-creation session APIs."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -82,6 +84,10 @@ async def list_voice_sessions(
|
||||
le=settings.voice_session_max_list_limit,
|
||||
),
|
||||
active_only: bool = Query(default=False),
|
||||
needs_attention: bool = Query(default=False),
|
||||
attention_reason: (
|
||||
Literal["pending_confirmation", "safety_intervention", "failed_turn"] | None
|
||||
) = Query(default=None),
|
||||
active_first: bool = Query(default=True),
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -92,6 +98,8 @@ async def list_voice_sessions(
|
||||
db,
|
||||
limit=limit,
|
||||
active_only=active_only,
|
||||
needs_attention=needs_attention,
|
||||
attention_reason=attention_reason,
|
||||
active_first=active_first,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,6 +121,7 @@ class VoiceSessionSummaryResponse(BaseModel):
|
||||
latest_safety_message: str | None = None
|
||||
latest_assistant_audio_ready: bool = False
|
||||
last_turn_status: str | None = None
|
||||
attention_reasons: list[str] = Field(default_factory=list)
|
||||
transcription_mode_hint: str | None = None
|
||||
can_continue: bool = False
|
||||
can_finalize: bool = False
|
||||
@@ -149,6 +150,10 @@ class VoiceSessionAnalyticsResponse(BaseModel):
|
||||
|
||||
window_days: int | None = None
|
||||
total_sessions: int = 0
|
||||
attention_sessions: int = 0
|
||||
confirmation_attention_sessions: int = 0
|
||||
safety_attention_sessions: int = 0
|
||||
failed_attention_sessions: int = 0
|
||||
active_sessions: int = 0
|
||||
finalized_sessions: int = 0
|
||||
abandoned_sessions: int = 0
|
||||
|
||||
@@ -388,6 +388,12 @@ def _session_to_summary(
|
||||
story_patch=latest_turn.story_patch or {},
|
||||
)
|
||||
latest_safety_state = _resolve_turn_safety_state(latest_turn.story_patch or {})
|
||||
attention_reasons = _build_session_attention_reasons(
|
||||
latest_requires_confirmation=latest_confirmation_state["requires_confirmation"],
|
||||
latest_safety_flags=latest_safety_state["safety_flags"],
|
||||
last_turn_status=latest_turn.status if latest_turn else None,
|
||||
last_error=session.last_error,
|
||||
)
|
||||
|
||||
return VoiceSessionSummaryResponse(
|
||||
id=session.id,
|
||||
@@ -413,12 +419,55 @@ def _session_to_summary(
|
||||
session_audio_exists(latest_turn.assistant_audio_path) if latest_turn else False
|
||||
),
|
||||
last_turn_status=latest_turn.status if latest_turn else None,
|
||||
attention_reasons=attention_reasons,
|
||||
transcription_mode_hint=settings.voice_transcription_mode,
|
||||
can_continue=_session_can_continue(session),
|
||||
can_finalize=_can_finalize_with_latest_turn(session, latest_turn),
|
||||
last_error=session.last_error,
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _build_session_attention_reasons(
|
||||
*,
|
||||
latest_requires_confirmation: bool,
|
||||
latest_safety_flags: list[str] | None,
|
||||
last_turn_status: str | None,
|
||||
last_error: str | None,
|
||||
) -> list[str]:
|
||||
reasons: list[str] = []
|
||||
if latest_requires_confirmation:
|
||||
reasons.append("pending_confirmation")
|
||||
if latest_safety_flags:
|
||||
reasons.append("safety_intervention")
|
||||
if last_turn_status == "failed" or last_error:
|
||||
reasons.append("failed_turn")
|
||||
return reasons
|
||||
|
||||
|
||||
def _session_summary_needs_attention(summary: VoiceSessionSummaryResponse) -> bool:
    """Return True when the summary carries at least one attention reason."""
    return bool(summary.attention_reasons)
|
||||
|
||||
|
||||
def _session_summary_matches_attention_reason(
    summary: VoiceSessionSummaryResponse,
    attention_reason: str | None,
) -> bool:
    """Check whether a summary satisfies an optional attention-reason filter.

    Args:
        summary: The summary whose ``attention_reasons`` are inspected.
        attention_reason: The reason to look for, or ``None`` for "no filter".

    Returns:
        True when no filter is given, otherwise whether ``attention_reason``
        is present in ``summary.attention_reasons``.
    """
    # No filter requested: every summary matches.
    if attention_reason is None:
        return True
    return attention_reason in summary.attention_reasons
|
||||
|
||||
|
||||
async def _build_session_summary(
    db: AsyncSession,
    session: VoiceSession,
) -> VoiceSessionSummaryResponse:
    """Build the summary response for one voice session.

    Fetches the session's latest turn via ``_get_latest_turn`` and passes it,
    together with ``session.current_turn_index`` as the total turn count, to
    ``_session_to_summary``.

    Args:
        db: Database session handed to ``_get_latest_turn``.
        session: The voice session to summarize.

    Returns:
        The summary produced by ``_session_to_summary``.
    """
    latest_turn = await _get_latest_turn(db, session_id=session.id)
    return _session_to_summary(
        session,
        latest_turn=latest_turn,
        total_turns=session.current_turn_index,
    )
|
||||
|
||||
|
||||
@@ -1082,6 +1131,8 @@ async def list_voice_sessions_service(
|
||||
*,
|
||||
limit: int | None = None,
|
||||
active_only: bool = False,
|
||||
needs_attention: bool = False,
|
||||
attention_reason: str | None = None,
|
||||
active_first: bool = False,
|
||||
) -> list[VoiceSessionSummaryResponse]:
|
||||
resolved_limit = limit or settings.voice_session_default_list_limit
|
||||
@@ -1102,19 +1153,20 @@ async def list_voice_sessions_service(
|
||||
)
|
||||
else:
|
||||
query = query.order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at))
|
||||
query = query.limit(resolved_limit)
|
||||
if not needs_attention and attention_reason is None:
|
||||
query = query.limit(resolved_limit)
|
||||
|
||||
sessions = (await db.execute(query)).scalars().all()
|
||||
summaries: list[VoiceSessionSummaryResponse] = []
|
||||
for session in sessions:
|
||||
latest_turn = await _get_latest_turn(db, session_id=session.id)
|
||||
summaries.append(
|
||||
_session_to_summary(
|
||||
session,
|
||||
latest_turn=latest_turn,
|
||||
total_turns=session.current_turn_index,
|
||||
)
|
||||
)
|
||||
summary = await _build_session_summary(db, session)
|
||||
if needs_attention and not _session_summary_needs_attention(summary):
|
||||
continue
|
||||
if not _session_summary_matches_attention_reason(summary, attention_reason):
|
||||
continue
|
||||
summaries.append(summary)
|
||||
if (needs_attention or attention_reason is not None) and len(summaries) >= resolved_limit:
|
||||
break
|
||||
return summaries
|
||||
|
||||
|
||||
@@ -1134,12 +1186,7 @@ async def get_latest_active_voice_session_service(
|
||||
session = (await db.execute(query)).scalar_one_or_none()
|
||||
if session is None:
|
||||
return None
|
||||
latest_turn = await _get_latest_turn(db, session_id=session.id)
|
||||
return _session_to_summary(
|
||||
session,
|
||||
latest_turn=latest_turn,
|
||||
total_turns=session.current_turn_index,
|
||||
)
|
||||
return await _build_session_summary(db, session)
|
||||
|
||||
|
||||
async def get_voice_session_analytics_service(
|
||||
@@ -1172,8 +1219,25 @@ async def get_voice_session_analytics_service(
|
||||
sessions = (await db.execute(session_query)).scalars().all()
|
||||
turns = (await db.execute(turn_query)).scalars().all()
|
||||
events = (await db.execute(event_query)).scalars().all()
|
||||
session_summaries = [await _build_session_summary(db, session) for session in sessions]
|
||||
|
||||
total_sessions = len(sessions)
|
||||
attention_sessions = sum(
|
||||
1 for summary in session_summaries if _session_summary_needs_attention(summary)
|
||||
)
|
||||
confirmation_attention_sessions = sum(
|
||||
1
|
||||
for summary in session_summaries
|
||||
if "pending_confirmation" in summary.attention_reasons
|
||||
)
|
||||
safety_attention_sessions = sum(
|
||||
1
|
||||
for summary in session_summaries
|
||||
if "safety_intervention" in summary.attention_reasons
|
||||
)
|
||||
failed_attention_sessions = sum(
|
||||
1 for summary in session_summaries if "failed_turn" in summary.attention_reasons
|
||||
)
|
||||
active_sessions = sum(
|
||||
1 for session in sessions if session.status in CONTINUABLE_SESSION_STATUSES
|
||||
)
|
||||
@@ -1205,6 +1269,10 @@ async def get_voice_session_analytics_service(
|
||||
return VoiceSessionAnalyticsResponse(
|
||||
window_days=days,
|
||||
total_sessions=total_sessions,
|
||||
attention_sessions=attention_sessions,
|
||||
confirmation_attention_sessions=confirmation_attention_sessions,
|
||||
safety_attention_sessions=safety_attention_sessions,
|
||||
failed_attention_sessions=failed_attention_sessions,
|
||||
active_sessions=active_sessions,
|
||||
finalized_sessions=finalized_sessions,
|
||||
abandoned_sessions=abandoned_sessions,
|
||||
|
||||
@@ -681,6 +681,149 @@ async def test_voice_session_analytics_summarize_failures_and_confirmations(
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_attention_filter_and_analytics_count(
    db_session,
    auth_token,
):
    """Exercise the attention filters on the list endpoint and the analytics counters.

    Creates four sessions (normal, failed, pending confirmation, safety
    intervention), then checks that ``needs_attention`` / ``attention_reason``
    return exactly the expected sessions and that the analytics endpoint
    reports matching per-reason counts.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db

    # The try/finally wraps the patch setup as well, so the dependency override
    # is removed even if entering one of the patch contexts fails.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
            patch(
                "app.services.voice_session_service.transcribe_voice_audio",
                new_callable=AsyncMock,
            ) as mock_transcribe,
        ):
            # First generation call succeeds; the second raises, which the
            # assertions below expect to surface as a "failed_turn" session.
            mock_generate.side_effect = [
                StoryOutput(
                    mode="generated",
                    title="正常故事",
                    story_text="第一段温暖故事。",
                    cover_prompt_suggestion="normal cover",
                ),
                RuntimeError("provider down"),
            ]
            mock_tts.side_effect = [
                b"normal-audio",
                b"confirmation-audio",
                b"safety-audio",
            ]
            # NOTE(review): confidence 0.41 is presumably below the service's
            # confirmation threshold -- the test expects "pending_confirmation".
            mock_transcribe.return_value = VoiceTranscriptionResult(
                transcript_text="我想听一个会发光的小恐龙故事",
                confidence=0.41,
                provider="openai",
            )

            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                # Session 1: ordinary turn, expected to need no attention.
                response = await client.post("/api/voice-sessions", json={})
                normal_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{normal_session_id}/turns/fallback",
                    json={"transcript_text": "先讲一个温暖的普通故事"},
                )
                assert response.status_code == 202

                # Session 2: story generation raises RuntimeError.
                response = await client.post("/api/voice-sessions", json={})
                failed_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{failed_session_id}/turns/fallback",
                    json={"transcript_text": "这轮会触发 provider 异常"},
                )
                assert response.status_code == 202

                # Session 3: audio turn that goes through the mocked transcription.
                response = await client.post("/api/voice-sessions", json={})
                confirmation_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{confirmation_session_id}/turns",
                    files={
                        "audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
                    },
                )
                assert response.status_code == 202

                # Session 4: transcript expected to trigger a safety intervention.
                response = await client.post("/api/voice-sessions", json={})
                safety_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{safety_session_id}/turns/fallback",
                    json={"transcript_text": "我想听一个拿着炸弹互相打的故事"},
                )
                assert response.status_code == 202

                # Unfiltered attention list: all three flagged sessions, not the normal one.
                response = await client.get(
                    "/api/voice-sessions?needs_attention=true&limit=8"
                )
                assert response.status_code == 200
                attention_sessions = response.json()
                attention_session_ids = {item["id"] for item in attention_sessions}
                assert attention_session_ids == {
                    failed_session_id,
                    confirmation_session_id,
                    safety_session_id,
                }
                assert normal_session_id not in attention_session_ids
                attention_reason_sets = {
                    item["id"]: set(item["attention_reasons"]) for item in attention_sessions
                }
                assert attention_reason_sets[confirmation_session_id] == {
                    "pending_confirmation"
                }
                assert attention_reason_sets[safety_session_id] == {
                    "safety_intervention"
                }
                assert attention_reason_sets[failed_session_id] == {"failed_turn"}

                # Each specific reason filter returns exactly one session.
                response = await client.get(
                    "/api/voice-sessions?needs_attention=true&attention_reason=pending_confirmation"
                )
                assert response.status_code == 200
                confirmation_sessions = response.json()
                assert [item["id"] for item in confirmation_sessions] == [
                    confirmation_session_id
                ]

                response = await client.get(
                    "/api/voice-sessions?needs_attention=true&attention_reason=safety_intervention"
                )
                assert response.status_code == 200
                safety_sessions = response.json()
                assert [item["id"] for item in safety_sessions] == [safety_session_id]

                response = await client.get(
                    "/api/voice-sessions?needs_attention=true&attention_reason=failed_turn"
                )
                assert response.status_code == 200
                failed_sessions = response.json()
                assert [item["id"] for item in failed_sessions] == [failed_session_id]

                # Analytics counters must agree with the list filters above.
                response = await client.get("/api/voice-sessions/analytics?days=30")
                assert response.status_code == 200
                analytics = response.json()
                assert analytics["total_sessions"] == 4
                assert analytics["attention_sessions"] == 3
                assert analytics["confirmation_attention_sessions"] == 1
                assert analytics["safety_attention_sessions"] == 1
                assert analytics["failed_attention_sessions"] == 1
                assert analytics["failed_turns"] >= 1
                assert analytics["low_confidence_turns"] >= 1
                assert analytics["safety_interventions"] >= 1
    finally:
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_list_orders_recent_sessions_first(
|
||||
db_session,
|
||||
auth_token,
|
||||
|
||||
Reference in New Issue
Block a user