"""Voice co-creation session service.""" from __future__ import annotations from datetime import datetime, timezone from typing import Any from fastapi import HTTPException from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn from app.schemas.voice_session_schemas import ( VoiceSessionAbandonRequest, VoiceSessionCreateRequest, VoiceSessionDetailResponse, VoiceSessionFinalizeRequest, VoiceSessionFinalizeResponse, VoiceSessionSummaryResponse, VoiceTurnAcceptedResponse, VoiceTurnCreateFallbackRequest, VoiceTurnSummaryResponse, VoiceTurnUploadAcceptedResponse, ) from app.services.adapters.text.models import StoryOutput from app.services.memory_service import build_enhanced_memory_context from app.services.provider_router import generate_story_content, text_to_speech from app.services.story_service import create_story_from_result, validate_profile_and_universe from app.services.voice_session_storage import ( build_turn_assistant_audio_path, read_session_audio, session_audio_exists, write_session_audio, write_uploaded_user_audio, ) from app.services.voice_transcription_service import transcribe_voice_audio logger = get_logger(__name__) CONTINUABLE_SESSION_STATUSES = {"draft", "active", "waiting_user"} FINAL_SESSION_STATUSES = {"completed", "abandoned"} def _default_story_state() -> dict[str, Any]: return { "premise": None, "latest_direction": None, "cover_prompt": None, "narrative_segments": [], "safety_flags": [], "last_intent": None, } def _session_can_continue(session: VoiceSession) -> bool: return session.status in CONTINUABLE_SESSION_STATUSES def _session_can_finalize(session: VoiceSession) -> bool: segments = list((session.story_state or {}).get("narrative_segments") or []) return bool(segments) and session.status in {"active", "waiting_user"} def _utcnow() -> datetime: return datetime.now(timezone.utc) def _assistant_audio_url(session_id: str, turn_id: str, audio_path: str | None) -> str | None: if not session_audio_exists(audio_path): return None return f"/api/voice-sessions/{session_id}/turns/{turn_id}/audio" def _user_audio_url(session_id: str, turn_id: str, audio_path: str | None) -> str | None: if not session_audio_exists(audio_path): return None return f"/api/voice-sessions/{session_id}/turns/{turn_id}/user-audio" def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse: turn_patch = turn.story_patch or {} return VoiceTurnSummaryResponse( id=turn.id, session_id=turn.session_id, turn_index=turn.turn_index, status=turn.status, user_transcript=turn.user_transcript, transcript_confidence=turn.transcript_confidence, transcription_provider=turn_patch.get("transcription_provider"), detected_intent=turn.detected_intent, intent_confidence=turn.intent_confidence, assistant_text=turn.assistant_text, assistant_audio_ready=session_audio_exists(turn.assistant_audio_path), assistant_audio_url=_assistant_audio_url( turn.session_id, turn.id, turn.assistant_audio_path, ), user_audio_ready=session_audio_exists(turn.user_audio_path), user_audio_url=_user_audio_url(turn.session_id, turn.id, turn.user_audio_path), error_message=turn.error_message, created_at=turn.created_at, updated_at=turn.updated_at, ) def _session_to_summary( session: VoiceSession, *, latest_turn: VoiceTurn | None = None, total_turns: int | None = None, ) -> VoiceSessionSummaryResponse: if latest_turn is None: total_turns = total_turns if total_turns is not None else session.current_turn_index else: total_turns = total_turns if total_turns is not None else latest_turn.turn_index return VoiceSessionSummaryResponse( id=session.id, child_profile_id=session.child_profile_id, universe_id=session.universe_id, final_story_id=session.final_story_id, target_mode=session.target_mode, status=session.status, current_turn_index=session.current_turn_index, total_turns=total_turns or 0, working_title=session.working_title, story_state=session.story_state or {}, latest_user_transcript=session.latest_user_transcript, latest_assistant_text=session.latest_assistant_text, latest_detected_intent=latest_turn.detected_intent if latest_turn else None, latest_assistant_audio_ready=( session_audio_exists(latest_turn.assistant_audio_path) if latest_turn else False ), last_turn_status=latest_turn.status if latest_turn else None, can_continue=_session_can_continue(session), can_finalize=_session_can_finalize(session), last_error=session.last_error, created_at=session.created_at, updated_at=session.updated_at, ) async def _record_session_event( db: AsyncSession, *, session_id: str, turn_id: str | None, event_type: str, status: str, message: str | None = None, metadata: dict[str, Any] | None = None, ) -> VoiceSessionEvent: event = VoiceSessionEvent( session_id=session_id, turn_id=turn_id, event_type=event_type, status=status, message=message, event_metadata=metadata or {}, ) db.add(event) await db.commit() await db.refresh(event) return event async def _get_owned_session( db: AsyncSession, *, session_id: str, user_id: str, ) -> VoiceSession: result = await db.execute( select(VoiceSession).where( VoiceSession.id == session_id, VoiceSession.user_id == user_id, ) ) session = result.scalar_one_or_none() if not session: raise HTTPException(status_code=404, detail="Voice session not found") return session async def _get_latest_turn( db: AsyncSession, *, session_id: str, ) -> VoiceTurn | None: result = await db.execute( select(VoiceTurn) .where(VoiceTurn.session_id == session_id) .order_by(desc(VoiceTurn.turn_index)) .limit(1) ) return result.scalar_one_or_none() async def _get_owned_turn( db: AsyncSession, *, session_id: str, turn_id: str, user_id: str, ) -> VoiceTurn: result = await db.execute( select(VoiceTurn) .join(VoiceSession, VoiceTurn.session_id == VoiceSession.id) .where( VoiceTurn.id == turn_id, VoiceTurn.session_id == session_id, VoiceSession.user_id == user_id, ) ) turn = result.scalar_one_or_none() if not turn: raise HTTPException(status_code=404, detail="Voice turn not found") return turn def _detect_intent( transcript_text: str, *, current_turn_index: int, ) -> tuple[str, float]: normalized = transcript_text.replace(" ", "") if any(keyword in normalized for keyword in ("保存故事", "存起来", "保存吧", "保存到故事库")): return "save_story", 0.95 if any(keyword in normalized for keyword in ("先到这里", "讲完了", "结束吧", "停在这里")): return "end_story", 0.88 if current_turn_index == 0: return "start_story", 0.82 if any( keyword in normalized for keyword in ( "不要", "改成", "换成", "我想让", "让它", "改一下", "改一改", "其实", ) ): return "correct_story", 0.76 return "continue_story", 0.68 def _recent_story_text(session: VoiceSession) -> str: story_state = session.story_state or {} segments = list(story_state.get("narrative_segments") or []) if not segments: return "" return "\n\n".join(segments[-2:]) def _build_generation_prompt( *, session: VoiceSession, transcript_text: str, intent: str, ) -> str: recent_story = _recent_story_text(session) if intent == "start_story": return ( "你是 DreamWeaver 的儿童故事共创助手。" "请为 3-8 岁儿童写一个温暖、安全、适合继续接龙的故事开头。" f"孩子刚刚说:{transcript_text}。" "请只输出一小段自然的中文故事,不要分点,不要解释,不要写“故事开始”。" ) if intent == "correct_story": return ( "你是 DreamWeaver 的儿童故事共创助手。" f"当前故事最近两段如下:{recent_story or '(暂时还没有已讲述内容)'}。" f"孩子希望修正故事走向:{transcript_text}。" "请顺着已有内容自然接住这个修改,继续写一小段新故事。" "不要从头重讲,不要解释规则。" ) return ( "你是 DreamWeaver 的儿童故事共创助手。" f"当前故事最近两段如下:{recent_story or '(暂时还没有已讲述内容)'}。" f"孩子接着说:{transcript_text}。" "请继续写一小段新的儿童故事内容,让故事自然往下发展。" "不要分点,不要做旁白说明。" ) async def _generate_assistant_turn( db: AsyncSession, *, session: VoiceSession, transcript_text: str, intent: str, ) -> StoryOutput: memory_context = await build_enhanced_memory_context( session.child_profile_id, session.universe_id, db, ) prompt = _build_generation_prompt( session=session, transcript_text=transcript_text, intent=intent, ) return await generate_story_content( input_type="full_story", data=prompt, memory_context=memory_context, db=db, user_id=session.user_id, ) def _merge_story_state( session: VoiceSession, *, transcript_text: str, intent: str, assistant_result: StoryOutput | None, ) -> tuple[dict[str, Any], dict[str, Any]]: current_state = _default_story_state() | (session.story_state or {}) narrative_segments = list(current_state.get("narrative_segments") or []) if intent == "start_story" and not current_state.get("premise"): current_state["premise"] = transcript_text if assistant_result and assistant_result.story_text: narrative_segments.append(assistant_result.story_text.strip()) current_state["narrative_segments"] = narrative_segments current_state["latest_direction"] = transcript_text current_state["last_intent"] = intent if assistant_result and assistant_result.cover_prompt_suggestion: current_state["cover_prompt"] = assistant_result.cover_prompt_suggestion patch = { "intent": intent, "transcript_text": transcript_text, "segment_added": bool(assistant_result and assistant_result.story_text), "working_title": assistant_result.title if assistant_result else session.working_title, "cover_prompt": current_state.get("cover_prompt"), "narrative_segments_count": len(narrative_segments), } return current_state, patch async def _create_pending_turn( db: AsyncSession, *, session: VoiceSession, transcript_text: str, transcript_confidence: float | None, transcription_provider: str | None, user_audio_path: str | None = None, user_audio_mime_type: str | None = None, user_audio_duration_ms: int | None = None, ) -> tuple[VoiceSession, VoiceTurn]: if session.status not in CONTINUABLE_SESSION_STATUSES: raise HTTPException( status_code=409, detail="Voice session is not ready for another turn.", ) next_turn_index = session.current_turn_index + 1 detected_intent, intent_confidence = _detect_intent( transcript_text, current_turn_index=session.current_turn_index, ) turn = VoiceTurn( session_id=session.id, turn_index=next_turn_index, status="transcribing", user_audio_path=user_audio_path, user_audio_mime_type=user_audio_mime_type, user_audio_duration_ms=user_audio_duration_ms, user_transcript=transcript_text, transcript_confidence=transcript_confidence, detected_intent=detected_intent, intent_confidence=intent_confidence, story_patch={"transcription_provider": transcription_provider}, ) session.status = "processing_turn" session.current_turn_index = next_turn_index session.latest_user_transcript = transcript_text session.last_error = None session.updated_at = _utcnow() db.add(turn) await db.commit() await db.refresh(session) await db.refresh(turn) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="turn_received", status="received", message="Voice turn received.", metadata={ "turn_index": turn.turn_index, "has_user_audio": bool(user_audio_path), "transcription_provider": transcription_provider, }, ) if user_audio_path: await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="turn_audio_uploaded", status="succeeded", message="User audio uploaded for one voice turn.", metadata={ "mime_type": user_audio_mime_type, "audio_path": user_audio_path, }, ) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="turn_transcribed", status="succeeded", message="Voice turn transcript is available.", metadata={ "transcript_confidence": transcript_confidence, "transcription_provider": transcription_provider, }, ) return session, turn async def _process_pending_turn( db: AsyncSession, *, session: VoiceSession, turn: VoiceTurn, transcript_text: str, user_id: str, ) -> str: assistant_text: str | None = None assistant_result: StoryOutput | None = None detected_intent = turn.detected_intent intent_confidence = turn.intent_confidence try: await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="intent_resolved", status="succeeded", message="Turn intent resolved.", metadata={ "detected_intent": detected_intent, "intent_confidence": intent_confidence, }, ) if detected_intent == "save_story": assistant_text = "好的,这个故事已经准备好保存到故事库了。" elif detected_intent == "end_story": assistant_text = "好的,我们先把故事停在这里。想保存的话,现在就可以保存到故事库。" else: assistant_result = await _generate_assistant_turn( db, session=session, transcript_text=transcript_text, intent=detected_intent, ) assistant_text = assistant_result.story_text.strip() merged_state, story_patch = _merge_story_state( session, transcript_text=transcript_text, intent=detected_intent, assistant_result=assistant_result, ) story_patch["transcription_provider"] = ( (turn.story_patch or {}).get("transcription_provider") ) turn.story_patch = story_patch turn.assistant_text = assistant_text turn.status = "narrative_ready" session.story_state = merged_state session.latest_assistant_text = assistant_text session.status = "waiting_user" session.updated_at = _utcnow() if assistant_result and assistant_result.title and not session.working_title: session.working_title = assistant_result.title await db.commit() await db.refresh(session) await db.refresh(turn) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="story_patch_applied", status="succeeded", message="Story state updated after one turn.", metadata=story_patch, ) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="assistant_text_ready", status="succeeded", message="Assistant text response generated.", metadata={ "assistant_text_length": len(assistant_text or ""), "working_title": session.working_title, }, ) except Exception as exc: turn.status = "failed" turn.error_message = str(exc) session.status = "waiting_user" session.last_error = str(exc) session.updated_at = _utcnow() await db.commit() await db.refresh(session) await db.refresh(turn) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="session_failed", status="failed", message="Assistant narrative generation failed for one voice turn.", metadata={"error": str(exc), "turn_index": turn.turn_index}, ) logger.warning( "voice_turn_generation_failed", session_id=session.id, turn_id=turn.id, error=str(exc), ) return turn.status if assistant_text: try: audio_bytes = await text_to_speech( assistant_text, db=db, user_id=user_id, ) saved_path = write_session_audio( build_turn_assistant_audio_path(session.id, turn.turn_index), audio_bytes, ) turn.assistant_audio_path = saved_path turn.assistant_audio_duration_ms = None turn.status = "audio_ready" await db.commit() await db.refresh(turn) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="assistant_audio_ready", status="succeeded", message="Assistant audio response generated.", metadata={"audio_path": saved_path}, ) except Exception as exc: turn.status = "narrative_ready" turn.error_message = None session.last_error = None session.updated_at = _utcnow() await db.commit() await db.refresh(turn) await db.refresh(session) await _record_session_event( db, session_id=session.id, turn_id=turn.id, event_type="assistant_audio_failed", status="failed", message="Assistant audio generation failed, text response kept.", metadata={"error": str(exc)}, ) logger.warning( "voice_turn_audio_failed", session_id=session.id, turn_id=turn.id, error=str(exc), ) return turn.status async def list_voice_sessions_service( user_id: str, db: AsyncSession, *, limit: int = 8, active_only: bool = False, ) -> list[VoiceSessionSummaryResponse]: query = ( select(VoiceSession) .where(VoiceSession.user_id == user_id) .order_by(desc(VoiceSession.updated_at), desc(VoiceSession.created_at)) .limit(limit) ) if active_only: query = query.where(VoiceSession.status.in_(CONTINUABLE_SESSION_STATUSES)) sessions = (await db.execute(query)).scalars().all() summaries: list[VoiceSessionSummaryResponse] = [] for session in sessions: latest_turn = await _get_latest_turn(db, session_id=session.id) summaries.append( _session_to_summary( session, latest_turn=latest_turn, total_turns=session.current_turn_index, ) ) return summaries async def create_voice_session_service( request: VoiceSessionCreateRequest, user_id: str, db: AsyncSession, ) -> VoiceSessionSummaryResponse: profile_id, universe_id = await validate_profile_and_universe( request.child_profile_id, request.universe_id, user_id, db, ) session = VoiceSession( user_id=user_id, child_profile_id=profile_id, universe_id=universe_id, target_mode=request.target_mode, status="draft", story_state=_default_story_state(), ) db.add(session) await db.commit() await db.refresh(session) await _record_session_event( db, session_id=session.id, turn_id=None, event_type="session_created", status="succeeded", message="Voice co-creation session created.", metadata={ "child_profile_id": session.child_profile_id, "universe_id": session.universe_id, "target_mode": session.target_mode, }, ) await db.refresh(session) return _session_to_summary(session) async def get_voice_session_detail_service( session_id: str, user_id: str, db: AsyncSession, ) -> VoiceSessionDetailResponse: session = await _get_owned_session(db, session_id=session_id, user_id=user_id) turns = ( await db.execute( select(VoiceTurn) .where(VoiceTurn.session_id == session.id) .order_by(desc(VoiceTurn.turn_index)) .limit(10) ) ).scalars().all() turns = list(reversed(turns)) events = ( await db.execute( select(VoiceSessionEvent) .where(VoiceSessionEvent.session_id == session.id) .order_by(desc(VoiceSessionEvent.id)) .limit(50) ) ).scalars().all() events = list(reversed(events)) latest_turn = turns[-1] if turns else None summary = _session_to_summary( session, latest_turn=latest_turn, total_turns=session.current_turn_index, ) return VoiceSessionDetailResponse( **summary.model_dump(), recent_turns=[_turn_to_summary(turn) for turn in turns], events=[ { "id": event.id, "session_id": event.session_id, "turn_id": event.turn_id, "event_type": event.event_type, "status": event.status, "message": event.message, "event_metadata": event.event_metadata or {}, "created_at": event.created_at, } for event in events ], ) async def create_voice_turn_from_text_service( session_id: str, request: VoiceTurnCreateFallbackRequest, user_id: str, db: AsyncSession, ) -> VoiceTurnAcceptedResponse: session = await _get_owned_session(db, session_id=session_id, user_id=user_id) transcript_text = request.transcript_text.strip() session, turn = await _create_pending_turn( db, session=session, transcript_text=transcript_text, transcript_confidence=1.0, transcription_provider="fallback", user_audio_duration_ms=request.duration_ms, ) status = await _process_pending_turn( db, session=session, turn=turn, transcript_text=transcript_text, user_id=user_id, ) return VoiceTurnAcceptedResponse( turn_id=turn.id, session_id=session.id, status=status, ) async def create_voice_turn_from_upload_service( *, session_id: str, user_id: str, audio_bytes: bytes, file_name: str, mime_type: str | None, duration_ms: int | None, transcript_hint: str | None, db: AsyncSession, ) -> VoiceTurnUploadAcceptedResponse: session = await _get_owned_session(db, session_id=session_id, user_id=user_id) if session.status not in CONTINUABLE_SESSION_STATUSES: raise HTTPException( status_code=409, detail="Voice session is not ready for another turn.", ) next_turn_index = session.current_turn_index + 1 user_audio_path = write_uploaded_user_audio( session_id=session.id, turn_index=next_turn_index, file_name=file_name, mime_type=mime_type, audio_data=audio_bytes, ) transcription = await transcribe_voice_audio( audio_bytes=audio_bytes, file_name=file_name, mime_type=mime_type, transcript_hint=transcript_hint, ) session, turn = await _create_pending_turn( db, session=session, transcript_text=transcription.transcript_text, transcript_confidence=transcription.confidence, transcription_provider=transcription.provider, user_audio_path=user_audio_path, user_audio_mime_type=mime_type, user_audio_duration_ms=duration_ms, ) status = await _process_pending_turn( db, session=session, turn=turn, transcript_text=transcription.transcript_text, user_id=user_id, ) return VoiceTurnUploadAcceptedResponse( turn_id=turn.id, session_id=session.id, status=status, transcription_provider=transcription.provider, ) async def get_voice_turn_service( session_id: str, turn_id: str, user_id: str, db: AsyncSession, ) -> VoiceTurnSummaryResponse: turn = await _get_owned_turn( db, session_id=session_id, turn_id=turn_id, user_id=user_id, ) return _turn_to_summary(turn) async def get_voice_turn_audio_service( session_id: str, turn_id: str, user_id: str, db: AsyncSession, ) -> bytes: turn = await _get_owned_turn( db, session_id=session_id, turn_id=turn_id, user_id=user_id, ) if not session_audio_exists(turn.assistant_audio_path): raise HTTPException(status_code=404, detail="Voice turn audio not found") return read_session_audio(turn.assistant_audio_path) async def get_voice_turn_user_audio_service( session_id: str, turn_id: str, user_id: str, db: AsyncSession, ) -> tuple[bytes, str]: turn = await _get_owned_turn( db, session_id=session_id, turn_id=turn_id, user_id=user_id, ) if not session_audio_exists(turn.user_audio_path): raise HTTPException(status_code=404, detail="Uploaded user audio not found") return read_session_audio(turn.user_audio_path), (turn.user_audio_mime_type or "audio/webm") async def finalize_voice_session_service( session_id: str, request: VoiceSessionFinalizeRequest, user_id: str, db: AsyncSession, ) -> VoiceSessionFinalizeResponse: if not request.save_story: raise HTTPException( status_code=400, detail="Voice session finalize requires save_story=true in Phase A.", ) session = await _get_owned_session(db, session_id=session_id, user_id=user_id) if session.status in FINAL_SESSION_STATUSES: raise HTTPException(status_code=409, detail="Voice session is already closed.") if not _session_can_finalize(session): raise HTTPException(status_code=409, detail="Voice session is not ready to finalize.") session.status = "finalizing_story" session.updated_at = _utcnow() await db.commit() await db.refresh(session) await _record_session_event( db, session_id=session.id, turn_id=None, event_type="session_finalizing", status="running", message="Voice session is being finalized into a story.", metadata={"generate_cover": request.generate_cover}, ) story_state = session.story_state or {} narrative_segments = list(story_state.get("narrative_segments") or []) final_story_text = "\n\n".join( segment.strip() for segment in narrative_segments if segment.strip() ) if not final_story_text: raise HTTPException(status_code=409, detail="Voice session has no narrative to save.") story_result = StoryOutput( mode="generated", title=session.working_title or "一起编织的睡前故事", story_text=final_story_text, cover_prompt_suggestion=( (story_state.get("cover_prompt") or "") if request.generate_cover else "" ), ) story = await create_story_from_result( result=story_result, user_id=user_id, profile_id=session.child_profile_id, universe_id=session.universe_id, db=db, ) session.final_story_id = story.id session.status = "completed" session.last_error = None session.updated_at = _utcnow() await db.commit() await db.refresh(session) await _record_session_event( db, session_id=session.id, turn_id=None, event_type="session_saved_as_story", status="succeeded", message="Voice session finalized into a story.", metadata={"story_id": story.id}, ) return VoiceSessionFinalizeResponse( session_id=session.id, status=session.status, story_id=story.id, generation_job_id=None, ) async def abandon_voice_session_service( session_id: str, request: VoiceSessionAbandonRequest, user_id: str, db: AsyncSession, ) -> VoiceSessionSummaryResponse: session = await _get_owned_session(db, session_id=session_id, user_id=user_id) if session.status in FINAL_SESSION_STATUSES: raise HTTPException(status_code=409, detail="Voice session is already closed.") session.status = "abandoned" session.last_error = request.reason session.updated_at = _utcnow() await db.commit() await db.refresh(session) await _record_session_event( db, session_id=session.id, turn_id=None, event_type="session_abandoned", status="succeeded", message="Voice session abandoned by the user.", metadata={"reason": request.reason}, ) await db.refresh(session) latest_turn = await _get_latest_turn(db, session_id=session.id) return _session_to_summary( session, latest_turn=latest_turn, total_turns=session.current_turn_index, )