feat: complete voice session safety and confirmation flow
This commit is contained in:
@@ -19,12 +19,14 @@ from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from app.schemas.voice_session_schemas import (
|
||||
VoiceSessionAbandonRequest,
|
||||
VoiceSessionAnalyticsResponse,
|
||||
VoiceSessionCreateRequest,
|
||||
VoiceSessionDetailResponse,
|
||||
VoiceSessionFinalizeRequest,
|
||||
VoiceSessionFinalizeResponse,
|
||||
VoiceSessionSummaryResponse,
|
||||
VoiceTurnAcceptedResponse,
|
||||
VoiceTurnConfirmRequest,
|
||||
VoiceTurnCreateFallbackRequest,
|
||||
VoiceTurnSummaryResponse,
|
||||
VoiceTurnUploadAcceptedResponse,
|
||||
@@ -36,11 +38,13 @@ from app.services.voice_session_service import (
|
||||
create_voice_turn_from_upload_service,
|
||||
finalize_voice_session_service,
|
||||
get_latest_active_voice_session_service,
|
||||
get_voice_session_analytics_service,
|
||||
get_voice_session_detail_service,
|
||||
get_voice_turn_audio_service,
|
||||
get_voice_turn_service,
|
||||
get_voice_turn_user_audio_service,
|
||||
list_voice_sessions_service,
|
||||
resolve_voice_turn_confirmation_service,
|
||||
retry_voice_turn_audio_service,
|
||||
retry_voice_turn_service,
|
||||
)
|
||||
@@ -101,6 +105,16 @@ async def get_latest_active_voice_session(
|
||||
return await get_latest_active_voice_session_service(user.id, db)
|
||||
|
||||
|
||||
@router.get("/voice-sessions/analytics", response_model=VoiceSessionAnalyticsResponse)
|
||||
async def get_voice_session_analytics(
|
||||
days: int | None = Query(default=30, ge=1, le=365),
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get aggregate voice co-creation analytics for the current user."""
|
||||
return await get_voice_session_analytics_service(user.id, db, days=days)
|
||||
|
||||
|
||||
@router.get("/voice-sessions/{session_id}", response_model=VoiceSessionDetailResponse)
|
||||
async def get_voice_session(
|
||||
session_id: str,
|
||||
@@ -192,6 +206,27 @@ async def retry_voice_turn(
|
||||
return await retry_voice_turn_service(session_id, turn_id, user.id, db)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/voice-sessions/{session_id}/turns/{turn_id}/confirm",
|
||||
response_model=VoiceTurnSummaryResponse,
|
||||
)
|
||||
async def resolve_voice_turn_confirmation(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
request: VoiceTurnConfirmRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Resolve one pending confirmation before continuing the session."""
|
||||
return await resolve_voice_turn_confirmation_service(
|
||||
session_id,
|
||||
turn_id,
|
||||
request,
|
||||
user.id,
|
||||
db,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/voice-sessions/{session_id}/turns/{turn_id}/audio")
|
||||
async def get_voice_turn_audio(
|
||||
session_id: str,
|
||||
|
||||
@@ -42,6 +42,12 @@ class VoiceSessionFinalizeRequest(BaseModel):
|
||||
generate_final_audio: bool = False
|
||||
|
||||
|
||||
class VoiceTurnConfirmRequest(BaseModel):
|
||||
"""Resolve one pending confirmation before the story continues."""
|
||||
|
||||
action: Literal["accept", "retry_recording", "switch_to_text"]
|
||||
|
||||
|
||||
class VoiceSessionAbandonRequest(BaseModel):
|
||||
"""Explicitly abandon one in-progress session."""
|
||||
|
||||
@@ -75,8 +81,12 @@ class VoiceTurnSummaryResponse(BaseModel):
|
||||
intent_confidence: float | None = None
|
||||
understanding_summary: str | None = None
|
||||
requires_confirmation: bool = False
|
||||
confirmation_state: str = "not_needed"
|
||||
confirmation_reason: str | None = None
|
||||
confirmation_message: str | None = None
|
||||
safety_flags: list[str] = Field(default_factory=list)
|
||||
safety_blocked: bool = False
|
||||
safety_message: str | None = None
|
||||
assistant_text: str | None = None
|
||||
assistant_audio_ready: bool = False
|
||||
assistant_audio_url: str | None = None
|
||||
@@ -105,7 +115,10 @@ class VoiceSessionSummaryResponse(BaseModel):
|
||||
latest_detected_intent: str | None = None
|
||||
latest_understanding_summary: str | None = None
|
||||
latest_requires_confirmation: bool = False
|
||||
latest_confirmation_state: str | None = None
|
||||
latest_confirmation_message: str | None = None
|
||||
latest_safety_flags: list[str] = Field(default_factory=list)
|
||||
latest_safety_message: str | None = None
|
||||
latest_assistant_audio_ready: bool = False
|
||||
last_turn_status: str | None = None
|
||||
transcription_mode_hint: str | None = None
|
||||
@@ -131,6 +144,25 @@ class VoiceTurnAcceptedResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class VoiceSessionAnalyticsResponse(BaseModel):
|
||||
"""Aggregated voice co-creation analytics for one user."""
|
||||
|
||||
window_days: int | None = None
|
||||
total_sessions: int = 0
|
||||
active_sessions: int = 0
|
||||
finalized_sessions: int = 0
|
||||
abandoned_sessions: int = 0
|
||||
total_turns: int = 0
|
||||
successful_turns: int = 0
|
||||
failed_turns: int = 0
|
||||
asr_failures: int = 0
|
||||
tts_failures: int = 0
|
||||
low_confidence_turns: int = 0
|
||||
safety_interventions: int = 0
|
||||
turn_success_rate: float = 0.0
|
||||
finalize_conversion_rate: float = 0.0
|
||||
|
||||
|
||||
class VoiceSessionFinalizeResponse(BaseModel):
|
||||
"""Finalize response after a session is converted into a story."""
|
||||
|
||||
|
||||
135
backend/app/services/voice_session_safety.py
Normal file
135
backend/app/services/voice_session_safety.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Safety helpers for child-friendly voice co-creation sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
UNSAFE_KEYWORD_GROUPS: dict[str, tuple[str, ...]] = {
|
||||
"violence": (
|
||||
"打死",
|
||||
"杀掉",
|
||||
"砍伤",
|
||||
"流很多血",
|
||||
"炸弹",
|
||||
"爆炸",
|
||||
"开枪",
|
||||
"刀子",
|
||||
"互相打",
|
||||
),
|
||||
"horror": (
|
||||
"鬼屋",
|
||||
"鬼怪",
|
||||
"僵尸",
|
||||
"诅咒",
|
||||
"恶魔",
|
||||
"吃人",
|
||||
"恐怖",
|
||||
"吓死人",
|
||||
),
|
||||
"danger": (
|
||||
"毒药",
|
||||
"绑架",
|
||||
"自杀",
|
||||
"跳楼",
|
||||
"伤害自己",
|
||||
"把人关起来",
|
||||
),
|
||||
"adult": (
|
||||
"色情",
|
||||
"裸",
|
||||
"亲热",
|
||||
"不穿衣服",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VoiceSafetyResult:
|
||||
"""Result of one voice safety evaluation."""
|
||||
|
||||
is_safe: bool
|
||||
flags: list[str]
|
||||
replacement_text: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
def _collect_safety_flags(text: str) -> list[str]:
|
||||
normalized = text.replace(" ", "").strip()
|
||||
flags: list[str] = []
|
||||
if not normalized:
|
||||
return flags
|
||||
|
||||
for flag, keywords in UNSAFE_KEYWORD_GROUPS.items():
|
||||
if any(keyword in normalized for keyword in keywords):
|
||||
flags.append(flag)
|
||||
return flags
|
||||
|
||||
|
||||
def _redirect_prefix(flags: list[str]) -> str:
|
||||
if "adult" in flags:
|
||||
return "这个方向不适合小朋友的睡前故事。"
|
||||
if "danger" in flags or "violence" in flags:
|
||||
return "这个方向有点太危险了。"
|
||||
if "horror" in flags:
|
||||
return "这个方向有点太吓人了。"
|
||||
return "这个方向现在不太适合继续讲下去。"
|
||||
|
||||
|
||||
def build_child_safe_redirect(flags: list[str]) -> str:
|
||||
"""Build a child-friendly redirect prompt after an unsafe request."""
|
||||
|
||||
return (
|
||||
f"{_redirect_prefix(flags)}"
|
||||
"我们把它改成温柔、安全、适合小朋友的冒险吧。"
|
||||
"你可以试试说:让小伙伴一起想办法、让事情变得更明亮,或者让新朋友来帮忙。"
|
||||
)
|
||||
|
||||
|
||||
def build_safe_story_fallback(*, premise: str | None = None) -> str:
|
||||
"""Build a safe replacement narrative segment for unsafe assistant output."""
|
||||
|
||||
subject = (premise or "小伙伴们").strip()
|
||||
if len(subject) > 12:
|
||||
subject = subject[:12]
|
||||
|
||||
return (
|
||||
f"{subject}决定把眼前的难题变成一次温柔又勇敢的冒险。"
|
||||
"大家先停下来想一想,再一起找到一个善良、安全、让人安心的解决办法,"
|
||||
"故事也朝着明亮的方向继续展开。"
|
||||
)
|
||||
|
||||
|
||||
def check_user_transcript_safety(transcript_text: str) -> VoiceSafetyResult:
|
||||
"""Screen user transcript text before it enters the story flow."""
|
||||
|
||||
flags = _collect_safety_flags(transcript_text)
|
||||
if not flags:
|
||||
return VoiceSafetyResult(is_safe=True, flags=[])
|
||||
|
||||
message = build_child_safe_redirect(flags)
|
||||
return VoiceSafetyResult(
|
||||
is_safe=False,
|
||||
flags=flags,
|
||||
replacement_text=message,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
def check_assistant_output_safety(
|
||||
assistant_text: str,
|
||||
*,
|
||||
premise: str | None = None,
|
||||
) -> VoiceSafetyResult:
|
||||
"""Screen assistant output and replace it with a child-safe segment when needed."""
|
||||
|
||||
flags = _collect_safety_flags(assistant_text)
|
||||
if not flags:
|
||||
return VoiceSafetyResult(is_safe=True, flags=[])
|
||||
|
||||
replacement_text = build_safe_story_fallback(premise=premise)
|
||||
return VoiceSafetyResult(
|
||||
is_safe=False,
|
||||
flags=flags,
|
||||
replacement_text=replacement_text,
|
||||
message="系统已把不适合孩子的内容改写为更温和安全的版本。",
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -14,12 +14,14 @@ from app.core.logging import get_logger
|
||||
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
|
||||
from app.schemas.voice_session_schemas import (
|
||||
VoiceSessionAbandonRequest,
|
||||
VoiceSessionAnalyticsResponse,
|
||||
VoiceSessionCreateRequest,
|
||||
VoiceSessionDetailResponse,
|
||||
VoiceSessionFinalizeRequest,
|
||||
VoiceSessionFinalizeResponse,
|
||||
VoiceSessionSummaryResponse,
|
||||
VoiceTurnAcceptedResponse,
|
||||
VoiceTurnConfirmRequest,
|
||||
VoiceTurnCreateFallbackRequest,
|
||||
VoiceTurnSummaryResponse,
|
||||
VoiceTurnUploadAcceptedResponse,
|
||||
@@ -27,7 +29,15 @@ from app.schemas.voice_session_schemas import (
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
from app.services.provider_router import generate_story_content, text_to_speech
|
||||
from app.services.story_service import create_story_from_result, validate_profile_and_universe
|
||||
from app.services.story_service import (
|
||||
create_story_from_result,
|
||||
generate_story_cover,
|
||||
validate_profile_and_universe,
|
||||
)
|
||||
from app.services.voice_session_safety import (
|
||||
check_assistant_output_safety,
|
||||
check_user_transcript_safety,
|
||||
)
|
||||
from app.services.voice_session_storage import (
|
||||
build_turn_assistant_audio_path,
|
||||
read_session_audio,
|
||||
@@ -51,6 +61,7 @@ def _default_story_state() -> dict[str, Any]:
|
||||
"narrative_segments": [],
|
||||
"safety_flags": [],
|
||||
"last_intent": None,
|
||||
"final_summary": None,
|
||||
}
|
||||
|
||||
|
||||
@@ -121,7 +132,9 @@ def _build_confirmation_message(
|
||||
f"{normalized_transcript}。"
|
||||
)
|
||||
else:
|
||||
natural_understanding = f"我现在先理解成你想「{_format_intent_label(detected_intent)}」。"
|
||||
natural_understanding = (
|
||||
f"我现在先理解成你想「{_format_intent_label(detected_intent)}」。"
|
||||
)
|
||||
|
||||
if "intent_unknown" in reasons:
|
||||
prefix = "我这一次还没有完全听懂。"
|
||||
@@ -141,6 +154,34 @@ def _build_confirmation_message(
|
||||
)
|
||||
|
||||
|
||||
def _merge_unique_items(*values: list[str] | tuple[str, ...]) -> list[str]:
|
||||
merged: list[str] = []
|
||||
for value in values:
|
||||
for item in value:
|
||||
normalized = str(item).strip()
|
||||
if normalized and normalized not in merged:
|
||||
merged.append(normalized)
|
||||
return merged
|
||||
|
||||
|
||||
def _confirmation_state_from_patch(story_patch: dict[str, Any] | None = None) -> str:
|
||||
patch = story_patch or {}
|
||||
if isinstance(patch.get("confirmation_state"), str):
|
||||
return str(patch["confirmation_state"])
|
||||
if patch.get("requires_confirmation"):
|
||||
return "pending"
|
||||
return "not_needed"
|
||||
|
||||
|
||||
def _resolve_turn_safety_state(story_patch: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
patch = story_patch or {}
|
||||
return {
|
||||
"safety_flags": list(patch.get("safety_flags") or []),
|
||||
"safety_blocked": bool(patch.get("safety_blocked") or False),
|
||||
"safety_message": patch.get("safety_message"),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_turn_confirmation_state(
|
||||
*,
|
||||
transcript_text: str | None,
|
||||
@@ -154,6 +195,7 @@ def _resolve_turn_confirmation_state(
|
||||
confirmation_reason = patch.get("confirmation_reason")
|
||||
confirmation_message = patch.get("confirmation_message")
|
||||
understanding_summary = patch.get("understanding_summary")
|
||||
confirmation_state = _confirmation_state_from_patch(patch)
|
||||
|
||||
reasons: list[str] = []
|
||||
if detected_intent == "unknown":
|
||||
@@ -188,11 +230,93 @@ def _resolve_turn_confirmation_state(
|
||||
return {
|
||||
"understanding_summary": understanding_summary,
|
||||
"requires_confirmation": bool(requires_confirmation),
|
||||
"confirmation_state": confirmation_state,
|
||||
"confirmation_reason": confirmation_reason,
|
||||
"confirmation_message": confirmation_message,
|
||||
}
|
||||
|
||||
|
||||
def _turn_has_pending_confirmation(turn: VoiceTurn) -> bool:
|
||||
confirmation_state = _resolve_turn_confirmation_state(
|
||||
transcript_text=turn.user_transcript,
|
||||
transcript_confidence=turn.transcript_confidence,
|
||||
detected_intent=turn.detected_intent,
|
||||
intent_confidence=turn.intent_confidence,
|
||||
story_patch=turn.story_patch or {},
|
||||
)
|
||||
return confirmation_state["requires_confirmation"] and (
|
||||
confirmation_state["confirmation_state"] == "pending"
|
||||
)
|
||||
|
||||
|
||||
def _extract_first_sentence(text: str | None) -> str:
|
||||
normalized = (text or "").strip().replace("\n", " ")
|
||||
if not normalized:
|
||||
return ""
|
||||
for separator in ("。", "!", "?", ".", "!", "?"):
|
||||
if separator in normalized:
|
||||
return normalized.split(separator, 1)[0].strip()
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_final_story_title(session: VoiceSession) -> str:
|
||||
candidates = [
|
||||
session.working_title,
|
||||
(session.story_state or {}).get("premise"),
|
||||
_extract_first_sentence(
|
||||
((session.story_state or {}).get("narrative_segments") or [None])[0]
|
||||
),
|
||||
"一起编织的睡前故事",
|
||||
]
|
||||
for candidate in candidates:
|
||||
normalized = str(candidate or "").strip(" \n\t。!?::-")
|
||||
if normalized:
|
||||
return normalized[:24]
|
||||
return "一起编织的睡前故事"
|
||||
|
||||
|
||||
def _build_final_story_summary(session: VoiceSession) -> str:
|
||||
story_state = session.story_state or {}
|
||||
segments = [
|
||||
segment.strip()
|
||||
for segment in list(story_state.get("narrative_segments") or [])
|
||||
if str(segment).strip()
|
||||
]
|
||||
if not segments:
|
||||
return "这是一段由孩子和 DreamWeaver 一起共创的温柔故事。"
|
||||
|
||||
first_sentence = _extract_first_sentence(segments[0])
|
||||
last_sentence = _extract_first_sentence(segments[-1])
|
||||
if first_sentence and last_sentence and first_sentence != last_sentence:
|
||||
return f"{first_sentence}。后来,{last_sentence}。"
|
||||
if first_sentence:
|
||||
return f"{first_sentence}。"
|
||||
return "这是一段由孩子和 DreamWeaver 一起共创的温柔故事。"
|
||||
|
||||
|
||||
def _turn_counts_as_success(turn: VoiceTurn) -> bool:
|
||||
patch = turn.story_patch or {}
|
||||
confirmation_state = _confirmation_state_from_patch(patch)
|
||||
if turn.status == "failed":
|
||||
return False
|
||||
if patch.get("safety_blocked"):
|
||||
return False
|
||||
if confirmation_state in {"pending", "retry_recording", "switch_to_text"}:
|
||||
return False
|
||||
return turn.status in {"audio_ready", "narrative_ready"}
|
||||
|
||||
|
||||
def _can_finalize_with_latest_turn(
|
||||
session: VoiceSession,
|
||||
latest_turn: VoiceTurn | None,
|
||||
) -> bool:
|
||||
if not _session_can_finalize(session):
|
||||
return False
|
||||
if latest_turn and _turn_has_pending_confirmation(latest_turn):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse:
|
||||
turn_patch = turn.story_patch or {}
|
||||
confirmation_state = _resolve_turn_confirmation_state(
|
||||
@@ -202,6 +326,7 @@ def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse:
|
||||
intent_confidence=turn.intent_confidence,
|
||||
story_patch=turn_patch,
|
||||
)
|
||||
safety_state = _resolve_turn_safety_state(turn_patch)
|
||||
return VoiceTurnSummaryResponse(
|
||||
id=turn.id,
|
||||
session_id=turn.session_id,
|
||||
@@ -214,8 +339,12 @@ def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse:
|
||||
intent_confidence=turn.intent_confidence,
|
||||
understanding_summary=confirmation_state["understanding_summary"],
|
||||
requires_confirmation=confirmation_state["requires_confirmation"],
|
||||
confirmation_state=confirmation_state["confirmation_state"],
|
||||
confirmation_reason=confirmation_state["confirmation_reason"],
|
||||
confirmation_message=confirmation_state["confirmation_message"],
|
||||
safety_flags=safety_state["safety_flags"],
|
||||
safety_blocked=safety_state["safety_blocked"],
|
||||
safety_message=safety_state["safety_message"],
|
||||
assistant_text=turn.assistant_text,
|
||||
assistant_audio_ready=session_audio_exists(turn.assistant_audio_path),
|
||||
assistant_audio_url=_assistant_audio_url(
|
||||
@@ -242,8 +371,13 @@ def _session_to_summary(
|
||||
latest_confirmation_state = {
|
||||
"understanding_summary": None,
|
||||
"requires_confirmation": False,
|
||||
"confirmation_state": None,
|
||||
"confirmation_message": None,
|
||||
}
|
||||
latest_safety_state = {
|
||||
"safety_flags": [],
|
||||
"safety_message": None,
|
||||
}
|
||||
else:
|
||||
total_turns = total_turns if total_turns is not None else latest_turn.turn_index
|
||||
latest_confirmation_state = _resolve_turn_confirmation_state(
|
||||
@@ -253,6 +387,7 @@ def _session_to_summary(
|
||||
intent_confidence=latest_turn.intent_confidence,
|
||||
story_patch=latest_turn.story_patch or {},
|
||||
)
|
||||
latest_safety_state = _resolve_turn_safety_state(latest_turn.story_patch or {})
|
||||
|
||||
return VoiceSessionSummaryResponse(
|
||||
id=session.id,
|
||||
@@ -270,14 +405,17 @@ def _session_to_summary(
|
||||
latest_detected_intent=latest_turn.detected_intent if latest_turn else None,
|
||||
latest_understanding_summary=latest_confirmation_state["understanding_summary"],
|
||||
latest_requires_confirmation=latest_confirmation_state["requires_confirmation"],
|
||||
latest_confirmation_state=latest_confirmation_state["confirmation_state"],
|
||||
latest_confirmation_message=latest_confirmation_state["confirmation_message"],
|
||||
latest_safety_flags=latest_safety_state["safety_flags"],
|
||||
latest_safety_message=latest_safety_state["safety_message"],
|
||||
latest_assistant_audio_ready=(
|
||||
session_audio_exists(latest_turn.assistant_audio_path) if latest_turn else False
|
||||
),
|
||||
last_turn_status=latest_turn.status if latest_turn else None,
|
||||
transcription_mode_hint=settings.voice_transcription_mode,
|
||||
can_continue=_session_can_continue(session),
|
||||
can_finalize=_session_can_finalize(session),
|
||||
can_finalize=_can_finalize_with_latest_turn(session, latest_turn),
|
||||
last_error=session.last_error,
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
@@ -468,6 +606,7 @@ def _merge_story_state(
|
||||
transcript_text: str,
|
||||
intent: str,
|
||||
assistant_result: StoryOutput | None,
|
||||
safety_flags: list[str] | None = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
current_state = _default_story_state() | (session.story_state or {})
|
||||
narrative_segments = list(current_state.get("narrative_segments") or [])
|
||||
@@ -481,6 +620,10 @@ def _merge_story_state(
|
||||
current_state["narrative_segments"] = narrative_segments
|
||||
current_state["latest_direction"] = transcript_text
|
||||
current_state["last_intent"] = intent
|
||||
current_state["safety_flags"] = _merge_unique_items(
|
||||
list(current_state.get("safety_flags") or []),
|
||||
list(safety_flags or []),
|
||||
)
|
||||
if assistant_result and assistant_result.cover_prompt_suggestion:
|
||||
current_state["cover_prompt"] = assistant_result.cover_prompt_suggestion
|
||||
|
||||
@@ -491,10 +634,24 @@ def _merge_story_state(
|
||||
"working_title": assistant_result.title if assistant_result else session.working_title,
|
||||
"cover_prompt": current_state.get("cover_prompt"),
|
||||
"narrative_segments_count": len(narrative_segments),
|
||||
"safety_flags": list(current_state.get("safety_flags") or []),
|
||||
}
|
||||
return current_state, patch
|
||||
|
||||
|
||||
async def _ensure_no_pending_confirmation(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: VoiceSession,
|
||||
) -> None:
|
||||
latest_turn = await _get_latest_turn(db, session_id=session.id)
|
||||
if latest_turn and _turn_has_pending_confirmation(latest_turn):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="请先确认上一轮系统理解,或选择重说 / 改成文本输入后再继续。",
|
||||
)
|
||||
|
||||
|
||||
async def _create_pending_turn(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -511,6 +668,7 @@ async def _create_pending_turn(
|
||||
status_code=409,
|
||||
detail="Voice session is not ready for another turn.",
|
||||
)
|
||||
await _ensure_no_pending_confirmation(db, session=session)
|
||||
|
||||
next_turn_index = session.current_turn_index + 1
|
||||
detected_intent, intent_confidence = _detect_intent(
|
||||
@@ -593,13 +751,18 @@ async def _process_pending_turn(
|
||||
assistant_result: StoryOutput | None = None
|
||||
detected_intent = turn.detected_intent
|
||||
intent_confidence = turn.intent_confidence
|
||||
turn_patch = dict(turn.story_patch or {})
|
||||
confirmation_state = _resolve_turn_confirmation_state(
|
||||
transcript_text=transcript_text,
|
||||
transcript_confidence=turn.transcript_confidence,
|
||||
detected_intent=detected_intent,
|
||||
intent_confidence=intent_confidence,
|
||||
story_patch=turn.story_patch or {},
|
||||
story_patch=turn_patch,
|
||||
)
|
||||
transcript_safety = check_user_transcript_safety(transcript_text)
|
||||
assistant_safety_message: str | None = None
|
||||
safety_flags: list[str] = []
|
||||
transcript_blocked = False
|
||||
|
||||
try:
|
||||
await _record_session_event(
|
||||
@@ -669,6 +832,70 @@ async def _process_pending_turn(
|
||||
"requires_confirmation": True,
|
||||
},
|
||||
)
|
||||
elif not transcript_safety.is_safe:
|
||||
transcript_blocked = True
|
||||
safety_flags = list(transcript_safety.flags)
|
||||
current_state = _default_story_state() | (session.story_state or {})
|
||||
current_state["safety_flags"] = _merge_unique_items(
|
||||
list(current_state.get("safety_flags") or []),
|
||||
safety_flags,
|
||||
)
|
||||
assistant_text = transcript_safety.replacement_text or transcript_safety.message
|
||||
turn.story_patch = {
|
||||
**turn_patch,
|
||||
"intent": detected_intent,
|
||||
"transcript_text": transcript_text,
|
||||
"segment_added": False,
|
||||
"working_title": session.working_title,
|
||||
"cover_prompt": current_state.get("cover_prompt"),
|
||||
"narrative_segments_count": len(
|
||||
list(current_state.get("narrative_segments") or [])
|
||||
),
|
||||
"requires_confirmation": False,
|
||||
"confirmation_state": turn_patch.get("confirmation_state", "not_needed"),
|
||||
"understanding_summary": confirmation_state["understanding_summary"],
|
||||
"safety_flags": safety_flags,
|
||||
"safety_blocked": True,
|
||||
"safety_message": transcript_safety.message,
|
||||
}
|
||||
turn.assistant_text = assistant_text
|
||||
turn.status = "narrative_ready"
|
||||
turn.error_message = None
|
||||
session.story_state = current_state
|
||||
session.latest_assistant_text = assistant_text
|
||||
session.status = "waiting_user"
|
||||
session.last_error = None
|
||||
session.updated_at = _utcnow()
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
await db.refresh(turn)
|
||||
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=turn.id,
|
||||
event_type="safety_intervention_requested",
|
||||
status="blocked",
|
||||
message="Unsafe user transcript was redirected to a child-friendly path.",
|
||||
metadata={
|
||||
"stage": "user_input",
|
||||
"safety_flags": safety_flags,
|
||||
},
|
||||
)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=turn.id,
|
||||
event_type="assistant_text_ready",
|
||||
status="succeeded",
|
||||
message="Assistant safety redirect generated.",
|
||||
metadata={
|
||||
"assistant_text_length": len(assistant_text or ""),
|
||||
"working_title": session.working_title,
|
||||
"requires_confirmation": False,
|
||||
"safety_flags": safety_flags,
|
||||
},
|
||||
)
|
||||
elif detected_intent == "save_story":
|
||||
assistant_text = "好的,这个故事已经准备好保存到故事库了。"
|
||||
elif detected_intent == "end_story":
|
||||
@@ -681,23 +908,47 @@ async def _process_pending_turn(
|
||||
intent=detected_intent,
|
||||
)
|
||||
assistant_text = assistant_result.story_text.strip()
|
||||
output_safety = check_assistant_output_safety(
|
||||
assistant_text,
|
||||
premise=str((session.story_state or {}).get("premise") or ""),
|
||||
)
|
||||
if not output_safety.is_safe:
|
||||
safety_flags = _merge_unique_items(safety_flags, output_safety.flags)
|
||||
assistant_safety_message = output_safety.message
|
||||
assistant_text = output_safety.replacement_text or assistant_text
|
||||
assistant_result = StoryOutput(
|
||||
mode=assistant_result.mode,
|
||||
title=assistant_result.title,
|
||||
story_text=assistant_text,
|
||||
cover_prompt_suggestion=assistant_result.cover_prompt_suggestion,
|
||||
)
|
||||
|
||||
if not confirmation_state["requires_confirmation"]:
|
||||
if not confirmation_state["requires_confirmation"] and not transcript_blocked:
|
||||
merged_state, story_patch = _merge_story_state(
|
||||
session,
|
||||
transcript_text=transcript_text,
|
||||
intent=detected_intent,
|
||||
assistant_result=assistant_result,
|
||||
safety_flags=safety_flags,
|
||||
)
|
||||
story_patch["transcription_provider"] = (
|
||||
(turn.story_patch or {}).get("transcription_provider")
|
||||
)
|
||||
story_patch["transcription_provider"] = turn_patch.get("transcription_provider")
|
||||
story_patch["requires_confirmation"] = False
|
||||
story_patch["confirmation_state"] = turn_patch.get("confirmation_state", "not_needed")
|
||||
story_patch["understanding_summary"] = confirmation_state["understanding_summary"]
|
||||
if turn_patch.get("confirmation_reason"):
|
||||
story_patch["confirmation_reason"] = turn_patch.get("confirmation_reason")
|
||||
story_patch["confirmation_message"] = None
|
||||
story_patch["safety_flags"] = safety_flags
|
||||
story_patch["safety_blocked"] = False
|
||||
story_patch["safety_message"] = assistant_safety_message
|
||||
turn.story_patch = story_patch
|
||||
turn.assistant_text = assistant_text
|
||||
turn.status = "narrative_ready"
|
||||
turn.error_message = None
|
||||
session.story_state = merged_state
|
||||
session.latest_assistant_text = assistant_text
|
||||
session.status = "waiting_user"
|
||||
session.last_error = None
|
||||
session.updated_at = _utcnow()
|
||||
if assistant_result and assistant_result.title and not session.working_title:
|
||||
session.working_title = assistant_result.title
|
||||
@@ -714,6 +965,19 @@ async def _process_pending_turn(
|
||||
message="Story state updated after one turn.",
|
||||
metadata=story_patch,
|
||||
)
|
||||
if safety_flags:
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=turn.id,
|
||||
event_type="safety_intervention_requested",
|
||||
status="rewritten",
|
||||
message="Assistant output was rewritten to keep the story child-friendly.",
|
||||
metadata={
|
||||
"stage": "assistant_output",
|
||||
"safety_flags": safety_flags,
|
||||
},
|
||||
)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
@@ -725,6 +989,7 @@ async def _process_pending_turn(
|
||||
"assistant_text_length": len(assistant_text or ""),
|
||||
"working_title": session.working_title,
|
||||
"requires_confirmation": False,
|
||||
"safety_flags": safety_flags,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -805,6 +1070,12 @@ async def _process_pending_turn(
|
||||
return turn.status
|
||||
|
||||
|
||||
def _confirmation_resolution_text(action: str) -> str:
|
||||
if action == "retry_recording":
|
||||
return "好的,我们把这一轮先撤回,你可以重新录一遍,我会重新认真听。"
|
||||
return "好的,我们先切换成文本输入。你可以直接在下面把这一轮想法改写清楚,我们再继续讲。"
|
||||
|
||||
|
||||
async def list_voice_sessions_service(
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
@@ -871,6 +1142,84 @@ async def get_latest_active_voice_session_service(
|
||||
)
|
||||
|
||||
|
||||
async def get_voice_session_analytics_service(
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
days: int | None = 30,
|
||||
) -> VoiceSessionAnalyticsResponse:
|
||||
cutoff = None
|
||||
if days is not None:
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
session_query = select(VoiceSession).where(VoiceSession.user_id == user_id)
|
||||
turn_query = (
|
||||
select(VoiceTurn)
|
||||
.join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
|
||||
.where(VoiceSession.user_id == user_id)
|
||||
)
|
||||
event_query = (
|
||||
select(VoiceSessionEvent)
|
||||
.join(VoiceSession, VoiceSessionEvent.session_id == VoiceSession.id)
|
||||
.where(VoiceSession.user_id == user_id)
|
||||
)
|
||||
|
||||
if cutoff is not None:
|
||||
session_query = session_query.where(VoiceSession.created_at >= cutoff)
|
||||
turn_query = turn_query.where(VoiceTurn.created_at >= cutoff)
|
||||
event_query = event_query.where(VoiceSessionEvent.created_at >= cutoff)
|
||||
|
||||
sessions = (await db.execute(session_query)).scalars().all()
|
||||
turns = (await db.execute(turn_query)).scalars().all()
|
||||
events = (await db.execute(event_query)).scalars().all()
|
||||
|
||||
total_sessions = len(sessions)
|
||||
active_sessions = sum(
|
||||
1 for session in sessions if session.status in CONTINUABLE_SESSION_STATUSES
|
||||
)
|
||||
finalized_sessions = sum(1 for session in sessions if session.status == "completed")
|
||||
abandoned_sessions = sum(1 for session in sessions if session.status == "abandoned")
|
||||
total_turns = len(turns)
|
||||
successful_turns = sum(1 for turn in turns if _turn_counts_as_success(turn))
|
||||
failed_turns = sum(1 for turn in turns if turn.status == "failed")
|
||||
asr_failures = sum(1 for event in events if event.event_type == "turn_transcription_failed")
|
||||
tts_failures = sum(
|
||||
1
|
||||
for event in events
|
||||
if event.event_type in {"assistant_audio_failed", "assistant_audio_retry_failed"}
|
||||
)
|
||||
low_confidence_turns = sum(
|
||||
1 for event in events if event.event_type == "turn_confirmation_requested"
|
||||
)
|
||||
safety_interventions = sum(
|
||||
1 for event in events if event.event_type == "safety_intervention_requested"
|
||||
)
|
||||
|
||||
turn_success_rate = (
|
||||
round(successful_turns / total_turns, 4) if total_turns else 0.0
|
||||
)
|
||||
finalize_conversion_rate = (
|
||||
round(finalized_sessions / total_sessions, 4) if total_sessions else 0.0
|
||||
)
|
||||
|
||||
return VoiceSessionAnalyticsResponse(
|
||||
window_days=days,
|
||||
total_sessions=total_sessions,
|
||||
active_sessions=active_sessions,
|
||||
finalized_sessions=finalized_sessions,
|
||||
abandoned_sessions=abandoned_sessions,
|
||||
total_turns=total_turns,
|
||||
successful_turns=successful_turns,
|
||||
failed_turns=failed_turns,
|
||||
asr_failures=asr_failures,
|
||||
tts_failures=tts_failures,
|
||||
low_confidence_turns=low_confidence_turns,
|
||||
safety_interventions=safety_interventions,
|
||||
turn_success_rate=turn_success_rate,
|
||||
finalize_conversion_rate=finalize_conversion_rate,
|
||||
)
|
||||
|
||||
|
||||
async def create_voice_session_service(
|
||||
request: VoiceSessionCreateRequest,
|
||||
user_id: str,
|
||||
@@ -1009,6 +1358,7 @@ async def create_voice_turn_from_upload_service(
|
||||
status_code=409,
|
||||
detail="Voice session is not ready for another turn.",
|
||||
)
|
||||
await _ensure_no_pending_confirmation(db, session=session)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail="上传音频为空,请重新录音后再试。")
|
||||
if len(audio_bytes) > settings.voice_turn_max_upload_bytes:
|
||||
@@ -1024,12 +1374,32 @@ async def create_voice_turn_from_upload_service(
|
||||
mime_type=mime_type,
|
||||
audio_data=audio_bytes,
|
||||
)
|
||||
transcription = await transcribe_voice_audio(
|
||||
audio_bytes=audio_bytes,
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
)
|
||||
try:
|
||||
transcription = await transcribe_voice_audio(
|
||||
audio_bytes=audio_bytes,
|
||||
file_name=file_name,
|
||||
mime_type=mime_type,
|
||||
transcript_hint=transcript_hint,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
session.last_error = str(exc.detail)
|
||||
session.updated_at = _utcnow()
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=None,
|
||||
event_type="turn_transcription_failed",
|
||||
status="failed",
|
||||
message="Voice transcription failed before one turn could be created.",
|
||||
metadata={
|
||||
"mime_type": mime_type,
|
||||
"audio_path": user_audio_path,
|
||||
"error": str(exc.detail),
|
||||
},
|
||||
)
|
||||
raise
|
||||
session, turn = await _create_pending_turn(
|
||||
db,
|
||||
session=session,
|
||||
@@ -1083,6 +1453,86 @@ async def retry_voice_turn_service(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_voice_turn_confirmation_service(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
request: VoiceTurnConfirmRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> VoiceTurnSummaryResponse:
|
||||
session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
|
||||
turn = await _get_owned_turn(
|
||||
db,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if turn.turn_index != session.current_turn_index:
|
||||
raise HTTPException(status_code=409, detail="Only the latest turn can be confirmed.")
|
||||
if not _turn_has_pending_confirmation(turn):
|
||||
raise HTTPException(status_code=409, detail="This turn does not need confirmation.")
|
||||
if not turn.user_transcript:
|
||||
raise HTTPException(status_code=409, detail="This turn has no transcript to confirm.")
|
||||
|
||||
patch = dict(turn.story_patch or {})
|
||||
patch["requires_confirmation"] = False
|
||||
patch["confirmation_state"] = "accepted" if request.action == "accept" else request.action
|
||||
patch["confirmation_message"] = None
|
||||
turn.story_patch = patch
|
||||
turn.error_message = None
|
||||
session.last_error = None
|
||||
session.updated_at = _utcnow()
|
||||
|
||||
if request.action == "accept":
|
||||
session.status = "processing_turn"
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
await db.refresh(turn)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=turn.id,
|
||||
event_type="turn_confirmation_accepted",
|
||||
status="succeeded",
|
||||
message=(
|
||||
"Parent confirmed the current interpretation "
|
||||
"and allowed the story to continue."
|
||||
),
|
||||
metadata={"turn_index": turn.turn_index},
|
||||
)
|
||||
await _process_pending_turn(
|
||||
db,
|
||||
session=session,
|
||||
turn=turn,
|
||||
transcript_text=turn.user_transcript,
|
||||
user_id=user_id,
|
||||
)
|
||||
await db.refresh(turn)
|
||||
return _turn_to_summary(turn)
|
||||
|
||||
guidance_text = _confirmation_resolution_text(request.action)
|
||||
turn.assistant_text = guidance_text
|
||||
turn.assistant_audio_path = None
|
||||
turn.assistant_audio_duration_ms = None
|
||||
turn.status = "narrative_ready"
|
||||
session.status = "waiting_user"
|
||||
session.latest_assistant_text = guidance_text
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
await db.refresh(turn)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=turn.id,
|
||||
event_type=f"turn_confirmation_{request.action}",
|
||||
status="succeeded",
|
||||
message="Pending confirmation was resolved without continuing the current transcript.",
|
||||
metadata={"turn_index": turn.turn_index, "action": request.action},
|
||||
)
|
||||
return _turn_to_summary(turn)
|
||||
|
||||
|
||||
async def retry_voice_turn_audio_service(
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
@@ -1202,9 +1652,10 @@ async def finalize_voice_session_service(
|
||||
)
|
||||
|
||||
session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
|
||||
latest_turn = await _get_latest_turn(db, session_id=session.id)
|
||||
if session.status in FINAL_SESSION_STATUSES:
|
||||
raise HTTPException(status_code=409, detail="Voice session is already closed.")
|
||||
if not _session_can_finalize(session):
|
||||
if not _can_finalize_with_latest_turn(session, latest_turn):
|
||||
raise HTTPException(status_code=409, detail="Voice session is not ready to finalize.")
|
||||
|
||||
session.status = "finalizing_story"
|
||||
@@ -1229,9 +1680,19 @@ async def finalize_voice_session_service(
|
||||
if not final_story_text:
|
||||
raise HTTPException(status_code=409, detail="Voice session has no narrative to save.")
|
||||
|
||||
final_title = _build_final_story_title(session)
|
||||
final_summary = _build_final_story_summary(session)
|
||||
story_state = {
|
||||
**story_state,
|
||||
"final_summary": final_summary,
|
||||
"final_title": final_title,
|
||||
}
|
||||
session.story_state = story_state
|
||||
session.working_title = final_title
|
||||
|
||||
story_result = StoryOutput(
|
||||
mode="generated",
|
||||
title=session.working_title or "一起编织的睡前故事",
|
||||
title=final_title,
|
||||
story_text=final_story_text,
|
||||
cover_prompt_suggestion=(
|
||||
(story_state.get("cover_prompt") or "") if request.generate_cover else ""
|
||||
@@ -1246,6 +1707,36 @@ async def finalize_voice_session_service(
|
||||
db=db,
|
||||
)
|
||||
|
||||
generation_job_id: str | None = None
|
||||
if request.generate_cover and story.cover_prompt:
|
||||
try:
|
||||
await generate_story_cover(story.id, user_id, db)
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=None,
|
||||
event_type="session_cover_generation_succeeded",
|
||||
status="succeeded",
|
||||
message="Finalized story cover was generated after session save.",
|
||||
metadata={"story_id": story.id},
|
||||
)
|
||||
except HTTPException as exc:
|
||||
await _record_session_event(
|
||||
db,
|
||||
session_id=session.id,
|
||||
turn_id=None,
|
||||
event_type="session_cover_generation_failed",
|
||||
status="failed",
|
||||
message="Finalized story cover generation failed after session save.",
|
||||
metadata={"story_id": story.id, "error": str(exc.detail)},
|
||||
)
|
||||
logger.warning(
|
||||
"voice_session_finalize_cover_failed",
|
||||
session_id=session.id,
|
||||
story_id=story.id,
|
||||
error=str(exc.detail),
|
||||
)
|
||||
|
||||
session.final_story_id = story.id
|
||||
session.status = "completed"
|
||||
session.last_error = None
|
||||
@@ -1260,14 +1751,18 @@ async def finalize_voice_session_service(
|
||||
event_type="session_saved_as_story",
|
||||
status="succeeded",
|
||||
message="Voice session finalized into a story.",
|
||||
metadata={"story_id": story.id},
|
||||
metadata={
|
||||
"story_id": story.id,
|
||||
"final_title": final_title,
|
||||
"final_summary": final_summary,
|
||||
},
|
||||
)
|
||||
|
||||
return VoiceSessionFinalizeResponse(
|
||||
session_id=session.id,
|
||||
status=session.status,
|
||||
story_id=story.id,
|
||||
generation_job_id=None,
|
||||
generation_job_id=generation_job_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -97,6 +98,10 @@ async def test_voice_session_correct_turn_and_finalize_to_story(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_cover",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate_cover,
|
||||
):
|
||||
mock_generate.side_effect = [
|
||||
StoryOutput(
|
||||
@@ -113,6 +118,7 @@ async def test_voice_session_correct_turn_and_finalize_to_story(
|
||||
),
|
||||
]
|
||||
mock_tts.side_effect = [b"turn-1-audio", b"turn-2-audio"]
|
||||
mock_generate_cover.return_value = "https://example.com/voice-cover.png"
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
@@ -165,6 +171,8 @@ async def test_voice_session_correct_turn_and_finalize_to_story(
|
||||
assert session_data["status"] == "completed"
|
||||
assert session_data["final_story_id"] == story_id
|
||||
assert session_data["can_continue"] is False
|
||||
assert session_data["story_state"]["final_summary"]
|
||||
mock_generate_cover.assert_awaited_once()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@@ -328,14 +336,22 @@ async def test_voice_session_low_confidence_turn_requests_confirmation(
|
||||
turn_data = response.json()
|
||||
assert turn_data["status"] == "audio_ready"
|
||||
assert turn_data["requires_confirmation"] is True
|
||||
assert turn_data["confirmation_state"] == "pending"
|
||||
assert turn_data["understanding_summary"].startswith("本轮系统理解为")
|
||||
assert "请家长帮忙确认" in turn_data["confirmation_message"]
|
||||
assert turn_data["assistant_text"] == turn_data["confirmation_message"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "我要直接继续下一轮"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
response = await client.get(f"/api/voice-sessions/{session_id}")
|
||||
assert response.status_code == 200
|
||||
session_data = response.json()
|
||||
assert session_data["latest_requires_confirmation"] is True
|
||||
assert session_data["latest_confirmation_state"] == "pending"
|
||||
assert "请家长帮忙确认" in session_data["latest_confirmation_message"]
|
||||
assert session_data["can_finalize"] is False
|
||||
assert session_data["story_state"]["narrative_segments"] == []
|
||||
@@ -349,6 +365,305 @@ async def test_voice_session_low_confidence_turn_requests_confirmation(
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_confirmation_accept_continues_original_turn(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
patch(
|
||||
"app.services.voice_session_service.transcribe_voice_audio",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_transcribe,
|
||||
):
|
||||
mock_generate.return_value = StoryOutput(
|
||||
mode="generated",
|
||||
title="小恐龙的星光之旅",
|
||||
story_text="小恐龙踩着亮晶晶的石头,朝着会唱歌的山谷慢慢走去。",
|
||||
cover_prompt_suggestion="A glowing little dinosaur walking into a musical valley",
|
||||
)
|
||||
mock_tts.side_effect = [b"confirmation-audio", b"story-audio"]
|
||||
mock_transcribe.return_value = VoiceTranscriptionResult(
|
||||
transcript_text="我想听一个会发光的小恐龙故事",
|
||||
confidence=0.44,
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns",
|
||||
files={
|
||||
"audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
|
||||
},
|
||||
)
|
||||
turn_id = response.json()["turn_id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}/confirm",
|
||||
json={"action": "accept"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
turn_data = response.json()
|
||||
assert turn_data["status"] == "audio_ready"
|
||||
assert turn_data["requires_confirmation"] is False
|
||||
assert turn_data["confirmation_state"] == "accepted"
|
||||
assert "小恐龙踩着亮晶晶的石头" in turn_data["assistant_text"]
|
||||
|
||||
response = await client.get(f"/api/voice-sessions/{session_id}")
|
||||
session_data = response.json()
|
||||
assert session_data["latest_requires_confirmation"] is False
|
||||
assert session_data["can_finalize"] is True
|
||||
assert len(session_data["story_state"]["narrative_segments"]) == 1
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_confirmation_switch_to_text_allows_follow_up_turn(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
patch(
|
||||
"app.services.voice_session_service.transcribe_voice_audio",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_transcribe,
|
||||
):
|
||||
mock_generate.return_value = StoryOutput(
|
||||
mode="generated",
|
||||
title="文字修正后的故事",
|
||||
story_text="小熊轻轻推开了云朵门,发现里面藏着一座会发光的图书馆。",
|
||||
cover_prompt_suggestion="A little bear opening a glowing cloud library door",
|
||||
)
|
||||
mock_tts.side_effect = [b"confirmation-audio", b"story-audio"]
|
||||
mock_transcribe.return_value = VoiceTranscriptionResult(
|
||||
transcript_text="我想听一个小熊和云朵门的故事",
|
||||
confidence=0.4,
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns",
|
||||
files={
|
||||
"audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
|
||||
},
|
||||
)
|
||||
turn_id = response.json()["turn_id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}/confirm",
|
||||
json={"action": "switch_to_text"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["confirmation_state"] == "switch_to_text"
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "我想听一个小熊打开云朵门去冒险的故事"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
|
||||
response = await client.get(f"/api/voice-sessions/{session_id}")
|
||||
session_data = response.json()
|
||||
assert session_data["latest_requires_confirmation"] is False
|
||||
assert session_data["can_finalize"] is True
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_unsafe_transcript_is_redirected_safely(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts, patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate:
|
||||
mock_tts.return_value = b"safe-redirect-audio"
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "我想听一个拿着炸弹互相打的故事"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
turn_id = response.json()["turn_id"]
|
||||
|
||||
response = await client.get(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}"
|
||||
)
|
||||
turn_data = response.json()
|
||||
assert turn_data["safety_blocked"] is True
|
||||
assert "violence" in turn_data["safety_flags"]
|
||||
assert "温柔、安全" in turn_data["assistant_text"]
|
||||
|
||||
response = await client.get(f"/api/voice-sessions/{session_id}")
|
||||
session_data = response.json()
|
||||
assert session_data["story_state"]["narrative_segments"] == []
|
||||
assert "violence" in session_data["latest_safety_flags"]
|
||||
|
||||
mock_generate.assert_not_awaited()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_analytics_summarize_failures_and_confirmations(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
patch(
|
||||
"app.services.voice_session_service.transcribe_voice_audio",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_transcribe,
|
||||
):
|
||||
mock_generate.side_effect = [
|
||||
StoryOutput(
|
||||
mode="generated",
|
||||
title="安全故事",
|
||||
story_text="第一段安全故事。",
|
||||
cover_prompt_suggestion="safe cover",
|
||||
),
|
||||
StoryOutput(
|
||||
mode="generated",
|
||||
title="确认后继续",
|
||||
story_text="第二段确认后顺利继续。",
|
||||
cover_prompt_suggestion="safe cover 2",
|
||||
),
|
||||
]
|
||||
mock_tts.side_effect = [
|
||||
RuntimeError("tts down"),
|
||||
b"confirmation-audio",
|
||||
b"confirmed-story-audio",
|
||||
]
|
||||
mock_transcribe.side_effect = [
|
||||
VoiceTranscriptionResult(
|
||||
transcript_text="我想听一个会发光的小恐龙故事",
|
||||
confidence=0.41,
|
||||
provider="openai",
|
||||
),
|
||||
HTTPException(status_code=503, detail="语音转写服务暂时不可用,请稍后重试。"),
|
||||
]
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "先给我一段故事"},
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns",
|
||||
files={
|
||||
"audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
|
||||
},
|
||||
)
|
||||
turn_id = response.json()["turn_id"]
|
||||
await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}/confirm",
|
||||
json={"action": "accept"},
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns",
|
||||
files={
|
||||
"audio_file": ("turn-2.webm", b"fake-webm-audio-2", "audio/webm"),
|
||||
},
|
||||
)
|
||||
assert response.status_code == 503
|
||||
|
||||
await client.post(
|
||||
f"/api/voice-sessions/{session_id}/finalize",
|
||||
json={"save_story": True, "generate_cover": False},
|
||||
)
|
||||
|
||||
response = await client.get("/api/voice-sessions/analytics?days=30")
|
||||
assert response.status_code == 200
|
||||
analytics = response.json()
|
||||
assert analytics["total_sessions"] >= 1
|
||||
assert analytics["successful_turns"] >= 1
|
||||
assert analytics["tts_failures"] >= 1
|
||||
assert analytics["low_confidence_turns"] >= 1
|
||||
assert analytics["asr_failures"] >= 1
|
||||
assert analytics["finalized_sessions"] >= 1
|
||||
assert analytics["finalize_conversion_rate"] > 0
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_list_orders_recent_sessions_first(
|
||||
db_session,
|
||||
auth_token,
|
||||
|
||||
Reference in New Issue
Block a user