"""Voice co-creation session service."""

from __future__ import annotations

from typing import Any

from fastapi import HTTPException
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import VoiceSession, VoiceSessionEvent, VoiceTurn
from app.schemas.voice_session_schemas import (
    VoiceSessionAbandonRequest,
    VoiceSessionCreateRequest,
    VoiceSessionDetailResponse,
    VoiceSessionFinalizeRequest,
    VoiceSessionFinalizeResponse,
    VoiceSessionSummaryResponse,
    VoiceTurnAcceptedResponse,
    VoiceTurnCreateFallbackRequest,
    VoiceTurnSummaryResponse,
)
from app.services.adapters.text.models import StoryOutput
from app.services.memory_service import build_enhanced_memory_context
from app.services.provider_router import generate_story_content, text_to_speech
from app.services.story_service import (
    create_story_from_result,
    validate_profile_and_universe,
)
from app.services.voice_session_storage import (
    build_turn_assistant_audio_path,
    read_session_audio,
    session_audio_exists,
    write_session_audio,
)

logger = get_logger(__name__)

ACTIVE_SESSION_STATUSES = {"draft", "active", "processing_turn", "waiting_user"}
CONTINUABLE_SESSION_STATUSES = {"draft", "active", "waiting_user"}
FINAL_SESSION_STATUSES = {"completed", "abandoned"}
# Idle statuses from which a session may be saved as a story.
FINALIZABLE_SESSION_STATUSES = {"active", "waiting_user"}


def _default_story_state() -> dict[str, Any]:
    """Return a fresh, empty story-state dictionary for a new session."""
    return {
        "premise": None,
        "latest_direction": None,
        "cover_prompt": None,
        "narrative_segments": [],
        "safety_flags": [],
        "last_intent": None,
    }


def _narrative_segments(story_state: dict[str, Any] | None) -> list[str]:
    """Return the non-blank narrative segments of *story_state*, stripped.

    Blank or non-string entries are skipped so they can neither make a session
    look finalizable nor contribute empty paragraphs to prompts or the saved
    story.
    """
    raw_segments = (story_state or {}).get("narrative_segments") or []
    return [
        segment.strip()
        for segment in raw_segments
        if isinstance(segment, str) and segment.strip()
    ]


def _session_can_continue(session: VoiceSession) -> bool:
    """Return True when the session accepts another user turn."""
    return session.status in CONTINUABLE_SESSION_STATUSES


def _session_can_finalize(session: VoiceSession) -> bool:
    """Return True when the session is idle and has real narrative text to save."""
    return session.status in FINALIZABLE_SESSION_STATUSES and bool(
        _narrative_segments(session.story_state)
    )


def _assistant_audio_url(session_id: str, turn_id: str, audio_path: str | None) -> str | None:
    """Return the API URL of a turn's assistant audio, or None if no audio file exists."""
    if not session_audio_exists(audio_path):
        return None
    return f"/api/voice-sessions/{session_id}/turns/{turn_id}/audio"


def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse:
    """Convert a VoiceTurn row into its API summary schema."""
    return VoiceTurnSummaryResponse(
        id=turn.id,
        session_id=turn.session_id,
        turn_index=turn.turn_index,
        status=turn.status,
        user_transcript=turn.user_transcript,
        transcript_confidence=turn.transcript_confidence,
        detected_intent=turn.detected_intent,
        intent_confidence=turn.intent_confidence,
        assistant_text=turn.assistant_text,
        assistant_audio_ready=session_audio_exists(turn.assistant_audio_path),
        assistant_audio_url=_assistant_audio_url(
            turn.session_id,
            turn.id,
            turn.assistant_audio_path,
        ),
        error_message=turn.error_message,
        created_at=turn.created_at,
        updated_at=turn.updated_at,
    )


def _session_to_summary(session: VoiceSession) -> VoiceSessionSummaryResponse:
    """Convert a VoiceSession row into its API summary schema."""
    return VoiceSessionSummaryResponse(
        id=session.id,
        child_profile_id=session.child_profile_id,
        universe_id=session.universe_id,
        final_story_id=session.final_story_id,
        target_mode=session.target_mode,
        status=session.status,
        current_turn_index=session.current_turn_index,
        working_title=session.working_title,
        story_state=session.story_state or {},
        latest_user_transcript=session.latest_user_transcript,
        latest_assistant_text=session.latest_assistant_text,
        can_continue=_session_can_continue(session),
        can_finalize=_session_can_finalize(session),
        last_error=session.last_error,
        created_at=session.created_at,
        updated_at=session.updated_at,
    )


async def _record_session_event(
    db: AsyncSession,
    *,
    session_id: str,
    turn_id: str | None,
    event_type: str,
    status: str,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> VoiceSessionEvent:
    """Insert and commit one VoiceSessionEvent row, returning the refreshed row.

    Note: this commits the whole unit of work, including any other pending
    changes on ``db``.
    """
    event = VoiceSessionEvent(
        session_id=session_id,
        turn_id=turn_id,
        event_type=event_type,
        status=status,
        message=message,
        event_metadata=metadata or {},
    )
    db.add(event)
    await db.commit()
    await db.refresh(event)
    return event


async def _rollback_if_inactive(db: AsyncSession) -> None:
    """Roll back *db* when a failed flush/commit left its transaction unusable.

    After such a failure SQLAlchemy rejects further work with
    ``PendingRollbackError`` until ``rollback()`` is called. Without this,
    the failure handlers' own commits would mask the original error.
    """
    if not db.is_active:
        await db.rollback()


async def _get_owned_session(
    db: AsyncSession,
    *,
    session_id: str,
    user_id: str,
) -> VoiceSession:
    """Load a voice session belonging to *user_id*.

    Raises:
        HTTPException: 404 if the session does not exist or is owned by
            another user.
    """
    result = await db.execute(
        select(VoiceSession).where(
            VoiceSession.id == session_id,
            VoiceSession.user_id == user_id,
        )
    )
    session = result.scalar_one_or_none()
    if not session:
        raise HTTPException(status_code=404, detail="Voice session not found")
    return session


async def _get_owned_turn(
    db: AsyncSession,
    *,
    session_id: str,
    turn_id: str,
    user_id: str,
) -> VoiceTurn:
    """Load a turn of *session_id*, checking ownership through the parent session.

    Raises:
        HTTPException: 404 if the turn does not exist, belongs to another
            session, or the session is owned by another user.
    """
    result = await db.execute(
        select(VoiceTurn)
        .join(VoiceSession, VoiceTurn.session_id == VoiceSession.id)
        .where(
            VoiceTurn.id == turn_id,
            VoiceTurn.session_id == session_id,
            VoiceSession.user_id == user_id,
        )
    )
    turn = result.scalar_one_or_none()
    if not turn:
        raise HTTPException(status_code=404, detail="Voice turn not found")
    return turn


def _detect_intent(
    transcript_text: str,
    *,
    current_turn_index: int,
) -> tuple[str, float]:
    """Classify a child's utterance with keyword heuristics.

    Returns an ``(intent, confidence)`` pair. Checks are applied in priority
    order: save, end, start (first turn only), correction, and finally
    continuation. The confidences are fixed heuristic constants.
    """
    # Strip spaces so keyword matching ignores spacing inserted by transcription.
    normalized = transcript_text.replace(" ", "")
    if any(keyword in normalized for keyword in ("保存故事", "存起来", "保存吧", "保存到故事库")):
        return "save_story", 0.95
    if any(keyword in normalized for keyword in ("先到这里", "讲完了", "结束吧", "停在这里")):
        return "end_story", 0.88
    if current_turn_index == 0:
        return "start_story", 0.82
    if any(
        keyword in normalized
        for keyword in (
            "不要",
            "改成",
            "换成",
            "我想让",
            "让它",
            "改一下",
            "改一改",
            "其实",
        )
    ):
        return "correct_story", 0.76
    return "continue_story", 0.68


def _recent_story_text(session: VoiceSession) -> str:
    """Return the last two non-blank narrative segments joined by a blank line, or ""."""
    segments = _narrative_segments(session.story_state)
    if not segments:
        return ""
    return "\n\n".join(segments[-2:])


def _build_generation_prompt(
    *,
    session: VoiceSession,
    transcript_text: str,
    intent: str,
) -> str:
    """Build the Chinese LLM prompt for the next story segment.

    ``start_story`` asks for an opening. ``correct_story`` asks the model to
    absorb a course change. Every other intent asks for a plain continuation.
    The last two segments are included as context where relevant.
    """
    recent_story = _recent_story_text(session)
    if intent == "start_story":
        return (
            "你是 DreamWeaver 的儿童故事共创助手。"
            "请为 3-8 岁儿童写一个温暖、安全、适合继续接龙的故事开头。"
            f"孩子刚刚说:{transcript_text}。"
            "请只输出一小段自然的中文故事,不要分点,不要解释,不要写“故事开始”。"
        )
    if intent == "correct_story":
        return (
            "你是 DreamWeaver 的儿童故事共创助手。"
            f"当前故事最近两段如下:{recent_story or '(暂时还没有已讲述内容)'}。"
            f"孩子希望修正故事走向:{transcript_text}。"
            "请顺着已有内容自然接住这个修改,继续写一小段新故事。"
            "不要从头重讲,不要解释规则。"
        )
    return (
        "你是 DreamWeaver 的儿童故事共创助手。"
        f"当前故事最近两段如下:{recent_story or '(暂时还没有已讲述内容)'}。"
        f"孩子接着说:{transcript_text}。"
        "请继续写一小段新的儿童故事内容,让故事自然往下发展。"
        "不要分点,不要做旁白说明。"
    )


async def _generate_assistant_turn(
    db: AsyncSession,
    *,
    session: VoiceSession,
    transcript_text: str,
    intent: str,
) -> StoryOutput:
    """Generate the assistant's next story segment through the provider router.

    The child's memory context (profile plus universe) and the intent-specific
    prompt are forwarded to ``generate_story_content``.
    """
    memory_context = await build_enhanced_memory_context(
        session.child_profile_id,
        session.universe_id,
        db,
    )
    prompt = _build_generation_prompt(
        session=session,
        transcript_text=transcript_text,
        intent=intent,
    )
    return await generate_story_content(
        input_type="full_story",
        data=prompt,
        memory_context=memory_context,
        db=db,
        user_id=session.user_id,
    )


def _merge_story_state(
    session: VoiceSession,
    *,
    transcript_text: str,
    intent: str,
    assistant_result: StoryOutput | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Compute the session's next story state after one turn.

    Returns ``(new_state, patch)``. ``new_state`` is a fresh dict to assign to
    ``session.story_state``, so ``session`` itself is not mutated. ``patch``
    is a summary of the turn stored on the VoiceTurn row.
    """
    current_state = _default_story_state() | (session.story_state or {})
    narrative_segments = list(current_state.get("narrative_segments") or [])

    if intent == "start_story" and not current_state.get("premise"):
        current_state["premise"] = transcript_text

    # Only non-blank text becomes a segment. A whitespace-only reply must not
    # store an empty segment or make the session look finalizable.
    new_segment = (
        (assistant_result.story_text or "").strip() if assistant_result else ""
    )
    if new_segment:
        narrative_segments.append(new_segment)

    current_state["narrative_segments"] = narrative_segments
    current_state["latest_direction"] = transcript_text
    current_state["last_intent"] = intent
    if assistant_result and assistant_result.cover_prompt_suggestion:
        current_state["cover_prompt"] = assistant_result.cover_prompt_suggestion

    patch = {
        "intent": intent,
        "transcript_text": transcript_text,
        "segment_added": bool(new_segment),
        "working_title": assistant_result.title if assistant_result else session.working_title,
        "cover_prompt": current_state.get("cover_prompt"),
        "narrative_segments_count": len(narrative_segments),
    }
    return current_state, patch


async def create_voice_session_service(
    request: VoiceSessionCreateRequest,
    user_id: str,
    db: AsyncSession,
) -> VoiceSessionSummaryResponse:
    """Create a ``draft`` voice co-creation session for a child profile/universe.

    Profile and universe are validated by ``validate_profile_and_universe``.
    A ``session_created`` event is recorded.
    """
    profile_id, universe_id = await validate_profile_and_universe(
        request.child_profile_id,
        request.universe_id,
        user_id,
        db,
    )
    session = VoiceSession(
        user_id=user_id,
        child_profile_id=profile_id,
        universe_id=universe_id,
        target_mode=request.target_mode,
        status="draft",
        story_state=_default_story_state(),
    )
    db.add(session)
    await db.commit()
    await db.refresh(session)

    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=None,
        event_type="session_created",
        status="succeeded",
        message="Voice co-creation session created.",
        metadata={
            "child_profile_id": session.child_profile_id,
            "universe_id": session.universe_id,
            "target_mode": session.target_mode,
        },
    )
    await db.refresh(session)
    return _session_to_summary(session)


async def get_voice_session_detail_service(
    session_id: str,
    user_id: str,
    db: AsyncSession,
) -> VoiceSessionDetailResponse:
    """Return a session summary with its latest 10 turns and latest 50 events.

    Both lists are returned in chronological (ascending) order.
    """
    session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
    # Fetch the newest rows first so LIMIT keeps the latest ones, then reverse.
    turns = (
        await db.execute(
            select(VoiceTurn)
            .where(VoiceTurn.session_id == session.id)
            .order_by(desc(VoiceTurn.turn_index))
            .limit(10)
        )
    ).scalars().all()
    turns = list(reversed(turns))

    events = (
        await db.execute(
            select(VoiceSessionEvent)
            .where(VoiceSessionEvent.session_id == session.id)
            .order_by(desc(VoiceSessionEvent.id))
            .limit(50)
        )
    ).scalars().all()
    events = list(reversed(events))

    summary = _session_to_summary(session)
    return VoiceSessionDetailResponse(
        **summary.model_dump(),
        recent_turns=[_turn_to_summary(turn) for turn in turns],
        events=[
            {
                "id": event.id,
                "session_id": event.session_id,
                "turn_id": event.turn_id,
                "event_type": event.event_type,
                "status": event.status,
                "message": event.message,
                "event_metadata": event.event_metadata or {},
                "created_at": event.created_at,
            }
            for event in events
        ],
    )


async def create_voice_turn_from_text_service(
    session_id: str,
    request: VoiceTurnCreateFallbackRequest,
    user_id: str,
    db: AsyncSession,
) -> VoiceTurnAcceptedResponse:
    """Process one user turn that was supplied as fallback text instead of audio.

    Steps:
    1. Persist the turn.
    2. Generate (or pick a canned) assistant reply and merge it into the
       session's story state.
    3. Attempt text-to-speech.

    A narrative failure marks the turn ``failed``. An audio failure keeps the
    text reply and leaves the turn ``narrative_ready``. In both cases the
    session ends in ``waiting_user``.

    Raises:
        HTTPException: 404 if the session is not owned by ``user_id``; 409 if
            the session is not in a continuable status.
    """
    session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
    if session.status not in CONTINUABLE_SESSION_STATUSES:
        raise HTTPException(
            status_code=409,
            detail="Voice session is not ready for another turn.",
        )

    transcript_text = request.transcript_text.strip()
    next_turn_index = session.current_turn_index + 1
    # NOTE(review): start_story is keyed on the turn counter. If the very first
    # turn fails, the retry is therefore classified as a continuation.
    detected_intent, intent_confidence = _detect_intent(
        transcript_text,
        current_turn_index=session.current_turn_index,
    )

    turn = VoiceTurn(
        session_id=session.id,
        turn_index=next_turn_index,
        status="transcribing",
        user_audio_duration_ms=request.duration_ms,
        user_transcript=transcript_text,
        transcript_confidence=1.0,
        detected_intent=detected_intent,
        intent_confidence=intent_confidence,
    )
    session.status = "processing_turn"
    session.current_turn_index = next_turn_index
    session.latest_user_transcript = transcript_text
    session.last_error = None
    db.add(turn)
    await db.commit()
    await db.refresh(session)
    await db.refresh(turn)

    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=turn.id,
        event_type="turn_received",
        status="received",
        message="Voice turn fallback text received.",
        metadata={"turn_index": turn.turn_index},
    )
    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=turn.id,
        event_type="turn_transcribed",
        status="succeeded",
        message="Fallback transcript accepted.",
        metadata={"transcript_confidence": turn.transcript_confidence},
    )

    assistant_text: str | None = None
    assistant_result: StoryOutput | None = None

    try:
        await _record_session_event(
            db,
            session_id=session.id,
            turn_id=turn.id,
            event_type="intent_resolved",
            status="succeeded",
            message="Turn intent resolved.",
            metadata={
                "detected_intent": detected_intent,
                "intent_confidence": intent_confidence,
            },
        )

        # Save/end intents get a canned reply; the other intents ask the model.
        if detected_intent == "save_story":
            assistant_text = "好的,这个故事已经准备好保存到故事库了。"
        elif detected_intent == "end_story":
            assistant_text = "好的,我们先把故事停在这里。想保存的话,现在就可以保存到故事库。"
        else:
            assistant_result = await _generate_assistant_turn(
                db,
                session=session,
                transcript_text=transcript_text,
                intent=detected_intent,
            )
            assistant_text = assistant_result.story_text.strip()

        merged_state, story_patch = _merge_story_state(
            session,
            transcript_text=transcript_text,
            intent=detected_intent,
            assistant_result=assistant_result,
        )
        turn.story_patch = story_patch
        turn.assistant_text = assistant_text
        turn.status = "narrative_ready"

        session.story_state = merged_state
        session.latest_assistant_text = assistant_text
        session.status = "waiting_user"
        # The first generated title sticks; later titles do not overwrite it.
        if assistant_result and assistant_result.title and not session.working_title:
            session.working_title = assistant_result.title

        await db.commit()
        await db.refresh(session)
        await db.refresh(turn)

        await _record_session_event(
            db,
            session_id=session.id,
            turn_id=turn.id,
            event_type="story_patch_applied",
            status="succeeded",
            message="Story state updated after one turn.",
            metadata=story_patch,
        )
        await _record_session_event(
            db,
            session_id=session.id,
            turn_id=turn.id,
            event_type="assistant_text_ready",
            status="succeeded",
            message="Assistant text response generated.",
            metadata={
                "assistant_text_length": len(assistant_text or ""),
                "working_title": session.working_title,
            },
        )
    except Exception as exc:
        # A failed flush/commit leaves the transaction unusable; recover it
        # first so the failure can actually be persisted.
        await _rollback_if_inactive(db)
        turn.status = "failed"
        turn.error_message = str(exc)
        session.status = "waiting_user"
        session.last_error = str(exc)
        await db.commit()
        await db.refresh(session)
        await db.refresh(turn)
        await _record_session_event(
            db,
            session_id=session.id,
            turn_id=turn.id,
            event_type="session_failed",
            status="failed",
            message="Assistant narrative generation failed for one voice turn.",
            metadata={"error": str(exc), "turn_index": turn.turn_index},
        )
        logger.warning(
            "voice_turn_generation_failed",
            session_id=session.id,
            turn_id=turn.id,
            error=str(exc),
        )
        return VoiceTurnAcceptedResponse(
            turn_id=turn.id,
            session_id=session.id,
            status=turn.status,
        )

    if assistant_text:
        try:
            audio_bytes = await text_to_speech(
                assistant_text,
                db=db,
                user_id=user_id,
            )
            saved_path = write_session_audio(
                build_turn_assistant_audio_path(session.id, turn.turn_index),
                audio_bytes,
            )
            turn.assistant_audio_path = saved_path
            turn.assistant_audio_duration_ms = None
            turn.status = "audio_ready"
            await db.commit()
            await db.refresh(turn)
            await _record_session_event(
                db,
                session_id=session.id,
                turn_id=turn.id,
                event_type="assistant_audio_ready",
                status="succeeded",
                message="Assistant audio response generated.",
                metadata={"audio_path": saved_path},
            )
        except Exception as exc:
            # Audio is best-effort: keep the text reply and clear error fields.
            await _rollback_if_inactive(db)
            turn.status = "narrative_ready"
            turn.error_message = None
            session.last_error = None
            await db.commit()
            await db.refresh(turn)
            await db.refresh(session)
            await _record_session_event(
                db,
                session_id=session.id,
                turn_id=turn.id,
                event_type="assistant_audio_failed",
                status="failed",
                message="Assistant audio generation failed, text response kept.",
                metadata={"error": str(exc)},
            )
            logger.warning(
                "voice_turn_audio_failed",
                session_id=session.id,
                turn_id=turn.id,
                error=str(exc),
            )

    return VoiceTurnAcceptedResponse(
        turn_id=turn.id,
        session_id=session.id,
        status=turn.status,
    )


async def get_voice_turn_service(
    session_id: str,
    turn_id: str,
    user_id: str,
    db: AsyncSession,
) -> VoiceTurnSummaryResponse:
    """Return the summary of one turn owned (via its session) by *user_id*."""
    turn = await _get_owned_turn(
        db,
        session_id=session_id,
        turn_id=turn_id,
        user_id=user_id,
    )
    return _turn_to_summary(turn)


async def get_voice_turn_audio_service(
    session_id: str,
    turn_id: str,
    user_id: str,
    db: AsyncSession,
) -> bytes:
    """Return the stored assistant audio bytes for one turn.

    Raises:
        HTTPException: 404 if the turn is not found/owned or has no audio file.
    """
    turn = await _get_owned_turn(
        db,
        session_id=session_id,
        turn_id=turn_id,
        user_id=user_id,
    )
    if not session_audio_exists(turn.assistant_audio_path):
        raise HTTPException(status_code=404, detail="Voice turn audio not found")
    return read_session_audio(turn.assistant_audio_path)


async def finalize_voice_session_service(
    session_id: str,
    request: VoiceSessionFinalizeRequest,
    user_id: str,
    db: AsyncSession,
) -> VoiceSessionFinalizeResponse:
    """Save the session's accumulated narrative as a story and mark it ``completed``.

    All preconditions are validated before the session enters
    ``finalizing_story``. If story creation then fails, the session is
    restored to its previous status with ``last_error`` set and a
    ``session_finalize_failed`` event recorded. It is never left stuck in
    ``finalizing_story``. The original exception is re-raised.

    Raises:
        HTTPException: 400 if ``save_story`` is false; 404 if the session is
            not owned by ``user_id``; 409 if the session is closed, not ready,
            or has no narrative. Errors from ``create_story_from_result``
            propagate unchanged.
    """
    if not request.save_story:
        raise HTTPException(
            status_code=400,
            detail="Voice session finalize requires save_story=true in Phase A.",
        )

    session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
    if session.status in FINAL_SESSION_STATUSES:
        raise HTTPException(status_code=409, detail="Voice session is already closed.")
    if not _session_can_finalize(session):
        raise HTTPException(status_code=409, detail="Voice session is not ready to finalize.")

    # Validate the narrative *before* changing status, so a rejected request
    # cannot strand the session in "finalizing_story".
    story_state = session.story_state or {}
    final_story_text = "\n\n".join(_narrative_segments(story_state))
    if not final_story_text:
        raise HTTPException(status_code=409, detail="Voice session has no narrative to save.")

    previous_status = session.status
    session.status = "finalizing_story"
    await db.commit()
    await db.refresh(session)
    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=None,
        event_type="session_finalizing",
        status="running",
        message="Voice session is being finalized into a story.",
        metadata={"generate_cover": request.generate_cover},
    )

    story_result = StoryOutput(
        mode="generated",
        title=session.working_title or "一起编织的睡前故事",
        story_text=final_story_text,
        cover_prompt_suggestion=(
            (story_state.get("cover_prompt") or "") if request.generate_cover else ""
        ),
    )

    try:
        story = await create_story_from_result(
            result=story_result,
            user_id=user_id,
            profile_id=session.child_profile_id,
            universe_id=session.universe_id,
            db=db,
        )
    except Exception as exc:
        # Discard any partial story rows. Everything else was already committed.
        await db.rollback()
        session.status = previous_status
        session.last_error = str(exc)
        await db.commit()
        await db.refresh(session)
        await _record_session_event(
            db,
            session_id=session.id,
            turn_id=None,
            event_type="session_finalize_failed",
            status="failed",
            message="Voice session could not be saved as a story.",
            metadata={"error": str(exc)},
        )
        logger.warning(
            "voice_session_finalize_failed",
            session_id=session.id,
            error=str(exc),
        )
        raise

    session.final_story_id = story.id
    session.status = "completed"
    session.last_error = None
    await db.commit()
    await db.refresh(session)
    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=None,
        event_type="session_saved_as_story",
        status="succeeded",
        message="Voice session finalized into a story.",
        metadata={"story_id": story.id},
    )

    return VoiceSessionFinalizeResponse(
        session_id=session.id,
        status=session.status,
        story_id=story.id,
        generation_job_id=None,
    )


async def abandon_voice_session_service(
    session_id: str,
    request: VoiceSessionAbandonRequest,
    user_id: str,
    db: AsyncSession,
) -> VoiceSessionSummaryResponse:
    """Mark a non-closed session ``abandoned`` and store the reason in ``last_error``.

    Raises:
        HTTPException: 404 if the session is not owned by ``user_id``; 409 if
            it is already completed or abandoned.
    """
    session = await _get_owned_session(db, session_id=session_id, user_id=user_id)
    if session.status in FINAL_SESSION_STATUSES:
        raise HTTPException(status_code=409, detail="Voice session is already closed.")

    session.status = "abandoned"
    session.last_error = request.reason
    await db.commit()
    await db.refresh(session)
    await _record_session_event(
        db,
        session_id=session.id,
        turn_id=None,
        event_type="session_abandoned",
        status="succeeded",
        message="Voice session abandoned by the user.",
        metadata={"reason": request.reason},
    )
    await db.refresh(session)
    return _session_to_summary(session)