def upgrade() -> None:
    """Create the voice co-creation tables and their indexes.

    Three tables are created in dependency order: ``voice_sessions``,
    ``voice_turns`` (child of a session) and ``voice_session_events``
    (child of a session, optionally linked to a turn).

    The ``created_at`` / ``updated_at`` columns are declared NOT NULL so the
    schema matches the ORM models, which annotate them as non-optional
    ``Mapped[datetime]``. Every such column has a ``now()`` server default,
    so inserts that omit the value still succeed.
    """
    op.create_table(
        "voice_sessions",
        sa.Column("id", sa.String(length=36), nullable=False),
        sa.Column("user_id", sa.String(length=255), nullable=False),
        sa.Column("child_profile_id", sa.String(length=36), nullable=True),
        sa.Column("universe_id", sa.String(length=36), nullable=True),
        sa.Column("final_story_id", sa.Integer(), nullable=True),
        sa.Column(
            "target_mode",
            sa.String(length=32),
            nullable=False,
            server_default="story",
        ),
        sa.Column(
            "status",
            sa.String(length=32),
            nullable=False,
            server_default="draft",
        ),
        sa.Column(
            "current_turn_index",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
        sa.Column("working_title", sa.String(length=255), nullable=True),
        sa.Column("story_state", sa.JSON(), nullable=False, server_default="{}"),
        sa.Column("latest_user_transcript", sa.Text(), nullable=True),
        sa.Column("latest_assistant_text", sa.Text(), nullable=True),
        sa.Column("last_error", sa.Text(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["child_profile_id"],
            ["child_profiles.id"],
            ondelete="SET NULL",
        ),
        sa.ForeignKeyConstraint(
            ["universe_id"],
            ["story_universes.id"],
            ondelete="SET NULL",
        ),
        sa.ForeignKeyConstraint(
            ["final_story_id"],
            ["stories.id"],
            ondelete="SET NULL",
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_voice_sessions_user_id", "voice_sessions", ["user_id"])
    op.create_index(
        "ix_voice_sessions_child_profile_id",
        "voice_sessions",
        ["child_profile_id"],
    )
    op.create_index("ix_voice_sessions_universe_id", "voice_sessions", ["universe_id"])
    op.create_index(
        "ix_voice_sessions_final_story_id",
        "voice_sessions",
        ["final_story_id"],
    )
    op.create_index("ix_voice_sessions_status", "voice_sessions", ["status"])
    op.create_index("ix_voice_sessions_created_at", "voice_sessions", ["created_at"])

    op.create_table(
        "voice_turns",
        sa.Column("id", sa.String(length=36), nullable=False),
        sa.Column("session_id", sa.String(length=36), nullable=False),
        sa.Column("turn_index", sa.Integer(), nullable=False),
        sa.Column(
            "status",
            sa.String(length=32),
            nullable=False,
            server_default="received",
        ),
        sa.Column("user_audio_path", sa.String(length=500), nullable=True),
        sa.Column("user_audio_mime_type", sa.String(length=100), nullable=True),
        sa.Column("user_audio_duration_ms", sa.Integer(), nullable=True),
        sa.Column("user_transcript", sa.Text(), nullable=True),
        sa.Column("transcript_confidence", sa.Float(), nullable=True),
        sa.Column(
            "detected_intent",
            sa.String(length=32),
            nullable=False,
            server_default="unknown",
        ),
        sa.Column("intent_confidence", sa.Float(), nullable=True),
        sa.Column("story_patch", sa.JSON(), nullable=False, server_default="{}"),
        sa.Column("assistant_text", sa.Text(), nullable=True),
        sa.Column("assistant_audio_path", sa.String(length=500), nullable=True),
        sa.Column("assistant_audio_duration_ms", sa.Integer(), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.ForeignKeyConstraint(["session_id"], ["voice_sessions.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        # One row per (session, turn_index): guards against duplicate turn numbers.
        sa.UniqueConstraint(
            "session_id",
            "turn_index",
            name="uq_voice_turn_session_turn_index",
        ),
    )
    op.create_index("ix_voice_turns_session_id", "voice_turns", ["session_id"])
    op.create_index("ix_voice_turns_status", "voice_turns", ["status"])
    op.create_index("ix_voice_turns_created_at", "voice_turns", ["created_at"])

    op.create_table(
        "voice_session_events",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("session_id", sa.String(length=36), nullable=False),
        sa.Column("turn_id", sa.String(length=36), nullable=True),
        sa.Column("event_type", sa.String(length=64), nullable=False),
        sa.Column("status", sa.String(length=32), nullable=False),
        sa.Column("message", sa.Text(), nullable=True),
        sa.Column("event_metadata", sa.JSON(), nullable=False, server_default="{}"),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
        sa.ForeignKeyConstraint(["session_id"], ["voice_sessions.id"], ondelete="CASCADE"),
        # Events outlive a deleted turn; only the link is cleared.
        sa.ForeignKeyConstraint(["turn_id"], ["voice_turns.id"], ondelete="SET NULL"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "ix_voice_session_events_session_id",
        "voice_session_events",
        ["session_id"],
    )
    op.create_index(
        "ix_voice_session_events_turn_id",
        "voice_session_events",
        ["turn_id"],
    )
    op.create_index(
        "ix_voice_session_events_created_at",
        "voice_session_events",
        ["created_at"],
    )
@router.post(
    "/voice-sessions/{session_id}/turns/fallback",
    response_model=VoiceTurnAcceptedResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def create_voice_turn_from_text(
    session_id: str,
    request: VoiceTurnCreateFallbackRequest,
    user: User = Depends(require_user),
    db: AsyncSession = Depends(get_db),
):
    """Create one turn using text fallback before real audio upload is added.

    The caller must be authenticated (``require_user``). The request is rate
    limited per user, then handed to ``create_voice_turn_from_text_service``
    together with the authenticated user's id and the request-scoped DB
    session. The service result is returned as-is and serialized as
    ``VoiceTurnAcceptedResponse`` with HTTP 202.

    NOTE(review): the service code visible in this change awaits story
    generation and text-to-speech before it returns, so the 202 appears to be
    sent after the turn has been processed rather than before. Confirm this
    is intended.
    """
    # Uses its own limiter key ("voice-turn:<user id>") so turn traffic does not
    # consume the session-creation budget; the request count and window are the
    # shared module constants (20 requests per window of 60; the unit is defined
    # by check_rate_limit and is presumably seconds).
    await check_rate_limit(
        f"voice-turn:{user.id}",
        VOICE_SESSION_RATE_LIMIT_REQUESTS,
        VOICE_SESSION_RATE_LIMIT_WINDOW,
    )
    return await create_voice_turn_from_text_service(session_id, request, user.id, db)
class VoiceSession(Base):
    """Voice co-creation session before it is finalized as a formal story.

    One row per session. Turns (``voice_turns``) and events
    (``voice_session_events``) reference this row and are deleted with it
    (their foreign keys use ``ondelete="CASCADE"``).
    """

    __tablename__ = "voice_sessions"

    # Primary key generated client-side by the module's _uuid helper
    # (presumably a UUID string, given the 36-character column; confirm).
    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=_uuid)
    # Owner of the session; the session is removed when the user is deleted.
    user_id: Mapped[str] = mapped_column(
        String(255), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Optional links below are cleared (SET NULL) rather than cascading, so a
    # session survives deletion of the profile, universe, or final story.
    child_profile_id: Mapped[str | None] = mapped_column(
        String(36), ForeignKey("child_profiles.id", ondelete="SET NULL"), nullable=True, index=True
    )
    universe_id: Mapped[str | None] = mapped_column(
        String(36), ForeignKey("story_universes.id", ondelete="SET NULL"), nullable=True, index=True
    )
    final_story_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("stories.id", ondelete="SET NULL"), nullable=True, index=True
    )
    target_mode: Mapped[str] = mapped_column(String(32), nullable=False, default="story")
    # Lifecycle state. The voice session service uses the values "draft",
    # "active", "processing_turn", "waiting_user", "completed" and "abandoned".
    status: Mapped[str] = mapped_column(String(32), nullable=False, default="draft", index=True)
    # Index of the most recent turn; 0 means no turn has been accepted yet.
    # The service increments it by one for every accepted turn.
    current_turn_index: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    working_title: Mapped[str | None] = mapped_column(String(255), nullable=True)
    # Accumulated story as a JSON object. The service seeds it with the keys
    # premise, latest_direction, cover_prompt, narrative_segments, safety_flags
    # and last_intent, and replaces the whole dict after each turn.
    # NOTE(review): a plain JSON column does not track in-place mutation, so
    # callers should keep assigning a new dict rather than mutating this one.
    story_state: Mapped[dict] = mapped_column(JSON, default=dict)
    latest_user_transcript: Mapped[str | None] = mapped_column(Text, nullable=True)
    latest_assistant_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Message of the last turn-processing failure; reset to None when a new
    # turn is accepted.
    last_error: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), index=True
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
class VoiceSessionEvent(Base):
    """Append-only event emitted by one voice co-creation session.

    The voice session service writes one row per processing step, for example
    "session_created", "turn_received", "turn_transcribed", "intent_resolved",
    "story_patch_applied", "assistant_text_ready", "assistant_audio_ready",
    "assistant_audio_failed" and "session_failed".
    """

    __tablename__ = "voice_session_events"

    # Auto-incrementing integer key; the detail endpoint orders events by it.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Events are deleted together with their session.
    session_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("voice_sessions.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Optional turn the event belongs to; cleared (not cascaded) if the turn
    # row is deleted. Session-level events store None here.
    turn_id: Mapped[str | None] = mapped_column(
        String(36), ForeignKey("voice_turns.id", ondelete="SET NULL"), nullable=True, index=True
    )
    event_type: Mapped[str] = mapped_column(String(64), nullable=False)
    # Outcome of the step; the service writes "received", "succeeded" or "failed".
    status: Mapped[str] = mapped_column(String(32), nullable=False)
    message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Free-form JSON payload for the event.
    # NOTE(review): presumably named event_metadata because "metadata" is
    # reserved on SQLAlchemy declarative classes; confirm.
    event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), index=True
    )
class VoiceTurnCreateFallbackRequest(BaseModel):
    """Create one voice turn using text fallback instead of uploaded audio."""

    # Text standing in for the speech transcript: 1 to
    # MAX_VOICE_TRANSCRIPT_LENGTH (1000) characters.
    # NOTE(review): the length bounds apply to the raw string, while the
    # service strips surrounding whitespace afterwards. A whitespace-only
    # value therefore passes validation and reaches the service as an empty
    # transcript. Consider stripping during validation.
    transcript_text: str = Field(..., min_length=1, max_length=MAX_VOICE_TRANSCRIPT_LENGTH)
    # Optional duration of the spoken input in milliseconds, between 1 and
    # MAX_VOICE_TURN_DURATION_MS (90_000); stored on the turn as
    # user_audio_duration_ms.
    duration_ms: int | None = Field(default=None, ge=1, le=MAX_VOICE_TURN_DURATION_MS)
class VoiceSessionSummaryResponse(BaseModel):
    """One summarized voice co-creation session.

    Mirrors the persisted session row (without the owning user id) and adds
    two flags computed by the service when the response is built.
    """

    id: str
    child_profile_id: str | None = None
    universe_id: str | None = None
    # Id of the story created from this session, once one exists.
    final_story_id: int | None = None
    target_mode: str
    status: str
    # Index of the latest turn; 0 before the first turn.
    current_turn_index: int
    working_title: str | None = None
    # Accumulated story state as stored on the session (a JSON object).
    story_state: dict[str, Any] = Field(default_factory=dict)
    latest_user_transcript: str | None = None
    latest_assistant_text: str | None = None
    # True when the service considers the session open for another turn
    # (status is "draft", "active" or "waiting_user").
    can_continue: bool = False
    # True when the story state holds at least one narrative segment and the
    # status is "active" or "waiting_user".
    can_finalize: bool = False
    last_error: str | None = None
    created_at: datetime
    updated_at: datetime
def _default_story_state() -> dict[str, Any]:
    """Build the empty story state stored on a newly created voice session.

    Returns:
        A fresh dict on every call, so the two list values are never shared
        between sessions. Scalar slots start as ``None`` and the list slots
        start empty.
    """
    # Key order is kept stable because the dict is persisted as JSON.
    state: dict[str, Any] = dict.fromkeys(
        ("premise", "latest_direction", "cover_prompt"),
        None,
    )
    state["narrative_segments"] = []
    state["safety_flags"] = []
    state["last_intent"] = None
    return state
def _turn_to_summary(turn: VoiceTurn) -> VoiceTurnSummaryResponse:
    """Convert one persisted voice turn into its API summary.

    Args:
        turn: The turn row to expose.

    Returns:
        A ``VoiceTurnSummaryResponse`` mirroring the turn's fields, plus
        ``assistant_audio_ready`` / ``assistant_audio_url`` describing whether
        synthesized audio can be fetched.
    """
    # Resolve audio availability once. The URL helper already performs the
    # storage existence check, so deriving the flag from its result avoids a
    # second lookup and keeps the flag and the URL from disagreeing if the
    # file appears or disappears between two checks.
    audio_url = _assistant_audio_url(
        turn.session_id,
        turn.id,
        turn.assistant_audio_path,
    )
    return VoiceTurnSummaryResponse(
        id=turn.id,
        session_id=turn.session_id,
        turn_index=turn.turn_index,
        status=turn.status,
        user_transcript=turn.user_transcript,
        transcript_confidence=turn.transcript_confidence,
        detected_intent=turn.detected_intent,
        intent_confidence=turn.intent_confidence,
        assistant_text=turn.assistant_text,
        assistant_audio_ready=audio_url is not None,
        assistant_audio_url=audio_url,
        error_message=turn.error_message,
        created_at=turn.created_at,
        updated_at=turn.updated_at,
    )
def _detect_intent(
    transcript_text: str,
    *,
    current_turn_index: int,
) -> tuple[str, float]:
    """Classify what the child wants from one transcript using keyword rules.

    Rules are evaluated in priority order: explicit save request, explicit
    end request, first turn of the session, correction of the story, and
    finally plain continuation.

    Args:
        transcript_text: The (already transcribed) user utterance.
        current_turn_index: Number of turns accepted so far; ``0`` means this
            is the first turn of the session.

    Returns:
        ``(intent, confidence)`` where intent is one of ``"save_story"``,
        ``"end_story"``, ``"start_story"``, ``"correct_story"`` or
        ``"continue_story"`` and confidence is the fixed score for that rule.
    """
    # Drop every whitespace character before matching. str.split() with no
    # separator splits on all Unicode whitespace, so keywords broken up by a
    # full-width space (U+3000), tab or newline are still recognized; removing
    # only ASCII spaces would miss them.
    normalized = "".join(transcript_text.split())

    def mentions(keywords: tuple[str, ...]) -> bool:
        return any(keyword in normalized for keyword in keywords)

    if mentions(("保存故事", "存起来", "保存吧", "保存到故事库")):
        return "save_story", 0.95
    if mentions(("先到这里", "讲完了", "结束吧", "停在这里")):
        return "end_story", 0.88
    # With no prior turns, anything that is not a save/end request starts the story.
    if current_turn_index == 0:
        return "start_story", 0.82
    if mentions(
        (
            "不要",
            "改成",
            "换成",
            "我想让",
            "让它",
            "改一下",
            "改一改",
            "其实",
        )
    ):
        return "correct_story", 0.76
    return "continue_story", 0.68
def _merge_story_state(
    session: VoiceSession,
    *,
    transcript_text: str,
    intent: str,
    assistant_result: StoryOutput | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Fold one processed turn into the session's story state.

    The session itself is not modified; the caller assigns the returned state.

    Args:
        session: Session whose current ``story_state`` is the starting point.
        transcript_text: What the user said in this turn.
        intent: Intent resolved for this turn.
        assistant_result: Generated story output, or ``None`` when the turn
            produced no new narrative (for example save/end requests).

    Returns:
        ``(new_state, patch)``. ``new_state`` is a new dict based on the
        default state overlaid with the session's stored state and this
        turn's changes. ``patch`` is a small summary of what this turn
        changed, suitable for storing on the turn and in event metadata.
    """
    current_state = _default_story_state() | (session.story_state or {})
    # Copy so the list held by the session's stored state is never mutated.
    narrative_segments = list(current_state.get("narrative_segments") or [])

    # The premise is fixed by the first start_story turn and never overwritten.
    if intent == "start_story" and not current_state.get("premise"):
        current_state["premise"] = transcript_text

    # Strip before deciding whether there is a segment: whitespace-only output
    # must not be stored as an empty segment, because a non-empty segment list
    # is what marks the session as finalizable.
    new_segment = ""
    if assistant_result and assistant_result.story_text:
        new_segment = assistant_result.story_text.strip()
    if new_segment:
        narrative_segments.append(new_segment)

    current_state["narrative_segments"] = narrative_segments
    current_state["latest_direction"] = transcript_text
    current_state["last_intent"] = intent
    if assistant_result and assistant_result.cover_prompt_suggestion:
        current_state["cover_prompt"] = assistant_result.cover_prompt_suggestion

    patch = {
        "intent": intent,
        "transcript_text": transcript_text,
        "segment_added": bool(new_segment),
        "working_title": assistant_result.title if assistant_result else session.working_title,
        "cover_prompt": current_state.get("cover_prompt"),
        "narrative_segments_count": len(narrative_segments),
    }
    return current_state, patch
.order_by(desc(VoiceSessionEvent.id)) + .limit(50) + ) + ).scalars().all() + events = list(reversed(events)) + + summary = _session_to_summary(session) + return VoiceSessionDetailResponse( + **summary.model_dump(), + recent_turns=[_turn_to_summary(turn) for turn in turns], + events=[ + { + "id": event.id, + "session_id": event.session_id, + "turn_id": event.turn_id, + "event_type": event.event_type, + "status": event.status, + "message": event.message, + "event_metadata": event.event_metadata or {}, + "created_at": event.created_at, + } + for event in events + ], + ) + + +async def create_voice_turn_from_text_service( + session_id: str, + request: VoiceTurnCreateFallbackRequest, + user_id: str, + db: AsyncSession, +) -> VoiceTurnAcceptedResponse: + session = await _get_owned_session(db, session_id=session_id, user_id=user_id) + if session.status not in CONTINUABLE_SESSION_STATUSES: + raise HTTPException( + status_code=409, + detail="Voice session is not ready for another turn.", + ) + + transcript_text = request.transcript_text.strip() + next_turn_index = session.current_turn_index + 1 + detected_intent, intent_confidence = _detect_intent( + transcript_text, + current_turn_index=session.current_turn_index, + ) + + turn = VoiceTurn( + session_id=session.id, + turn_index=next_turn_index, + status="transcribing", + user_audio_duration_ms=request.duration_ms, + user_transcript=transcript_text, + transcript_confidence=1.0, + detected_intent=detected_intent, + intent_confidence=intent_confidence, + ) + session.status = "processing_turn" + session.current_turn_index = next_turn_index + session.latest_user_transcript = transcript_text + session.last_error = None + db.add(turn) + await db.commit() + await db.refresh(session) + await db.refresh(turn) + + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="turn_received", + status="received", + message="Voice turn fallback text received.", + metadata={"turn_index": turn.turn_index}, + 
) + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="turn_transcribed", + status="succeeded", + message="Fallback transcript accepted.", + metadata={"transcript_confidence": turn.transcript_confidence}, + ) + + assistant_text: str | None = None + assistant_result: StoryOutput | None = None + + try: + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="intent_resolved", + status="succeeded", + message="Turn intent resolved.", + metadata={ + "detected_intent": detected_intent, + "intent_confidence": intent_confidence, + }, + ) + + if detected_intent == "save_story": + assistant_text = "好的,这个故事已经准备好保存到故事库了。" + elif detected_intent == "end_story": + assistant_text = "好的,我们先把故事停在这里。想保存的话,现在就可以保存到故事库。" + else: + assistant_result = await _generate_assistant_turn( + db, + session=session, + transcript_text=transcript_text, + intent=detected_intent, + ) + assistant_text = assistant_result.story_text.strip() + + merged_state, story_patch = _merge_story_state( + session, + transcript_text=transcript_text, + intent=detected_intent, + assistant_result=assistant_result, + ) + turn.story_patch = story_patch + turn.assistant_text = assistant_text + turn.status = "narrative_ready" + session.story_state = merged_state + session.latest_assistant_text = assistant_text + session.status = "waiting_user" + if assistant_result and assistant_result.title and not session.working_title: + session.working_title = assistant_result.title + await db.commit() + await db.refresh(session) + await db.refresh(turn) + + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="story_patch_applied", + status="succeeded", + message="Story state updated after one turn.", + metadata=story_patch, + ) + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="assistant_text_ready", + status="succeeded", + message="Assistant text response generated.", + 
metadata={ + "assistant_text_length": len(assistant_text or ""), + "working_title": session.working_title, + }, + ) + except Exception as exc: + turn.status = "failed" + turn.error_message = str(exc) + session.status = "waiting_user" + session.last_error = str(exc) + await db.commit() + await db.refresh(session) + await db.refresh(turn) + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="session_failed", + status="failed", + message="Assistant narrative generation failed for one voice turn.", + metadata={"error": str(exc), "turn_index": turn.turn_index}, + ) + logger.warning( + "voice_turn_generation_failed", + session_id=session.id, + turn_id=turn.id, + error=str(exc), + ) + return VoiceTurnAcceptedResponse( + turn_id=turn.id, + session_id=session.id, + status=turn.status, + ) + + if assistant_text: + try: + audio_bytes = await text_to_speech( + assistant_text, + db=db, + user_id=user_id, + ) + saved_path = write_session_audio( + build_turn_assistant_audio_path(session.id, turn.turn_index), + audio_bytes, + ) + turn.assistant_audio_path = saved_path + turn.assistant_audio_duration_ms = None + turn.status = "audio_ready" + await db.commit() + await db.refresh(turn) + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="assistant_audio_ready", + status="succeeded", + message="Assistant audio response generated.", + metadata={"audio_path": saved_path}, + ) + except Exception as exc: + turn.status = "narrative_ready" + turn.error_message = None + session.last_error = None + await db.commit() + await db.refresh(turn) + await db.refresh(session) + await _record_session_event( + db, + session_id=session.id, + turn_id=turn.id, + event_type="assistant_audio_failed", + status="failed", + message="Assistant audio generation failed, text response kept.", + metadata={"error": str(exc)}, + ) + logger.warning( + "voice_turn_audio_failed", + session_id=session.id, + turn_id=turn.id, + 
error=str(exc), + ) + + return VoiceTurnAcceptedResponse( + turn_id=turn.id, + session_id=session.id, + status=turn.status, + ) + + +async def get_voice_turn_service( + session_id: str, + turn_id: str, + user_id: str, + db: AsyncSession, +) -> VoiceTurnSummaryResponse: + turn = await _get_owned_turn( + db, + session_id=session_id, + turn_id=turn_id, + user_id=user_id, + ) + return _turn_to_summary(turn) + + +async def get_voice_turn_audio_service( + session_id: str, + turn_id: str, + user_id: str, + db: AsyncSession, +) -> bytes: + turn = await _get_owned_turn( + db, + session_id=session_id, + turn_id=turn_id, + user_id=user_id, + ) + if not session_audio_exists(turn.assistant_audio_path): + raise HTTPException(status_code=404, detail="Voice turn audio not found") + return read_session_audio(turn.assistant_audio_path) + + +async def finalize_voice_session_service( + session_id: str, + request: VoiceSessionFinalizeRequest, + user_id: str, + db: AsyncSession, +) -> VoiceSessionFinalizeResponse: + if not request.save_story: + raise HTTPException( + status_code=400, + detail="Voice session finalize requires save_story=true in Phase A.", + ) + + session = await _get_owned_session(db, session_id=session_id, user_id=user_id) + if session.status in FINAL_SESSION_STATUSES: + raise HTTPException(status_code=409, detail="Voice session is already closed.") + if not _session_can_finalize(session): + raise HTTPException(status_code=409, detail="Voice session is not ready to finalize.") + + session.status = "finalizing_story" + await db.commit() + await db.refresh(session) + await _record_session_event( + db, + session_id=session.id, + turn_id=None, + event_type="session_finalizing", + status="running", + message="Voice session is being finalized into a story.", + metadata={"generate_cover": request.generate_cover}, + ) + + story_state = session.story_state or {} + narrative_segments = list(story_state.get("narrative_segments") or []) + final_story_text = "\n\n".join( + 
segment.strip() for segment in narrative_segments if segment.strip() + ) + if not final_story_text: + raise HTTPException(status_code=409, detail="Voice session has no narrative to save.") + + story_result = StoryOutput( + mode="generated", + title=session.working_title or "一起编织的睡前故事", + story_text=final_story_text, + cover_prompt_suggestion=( + (story_state.get("cover_prompt") or "") if request.generate_cover else "" + ), + ) + + story = await create_story_from_result( + result=story_result, + user_id=user_id, + profile_id=session.child_profile_id, + universe_id=session.universe_id, + db=db, + ) + + session.final_story_id = story.id + session.status = "completed" + session.last_error = None + await db.commit() + await db.refresh(session) + + await _record_session_event( + db, + session_id=session.id, + turn_id=None, + event_type="session_saved_as_story", + status="succeeded", + message="Voice session finalized into a story.", + metadata={"story_id": story.id}, + ) + + return VoiceSessionFinalizeResponse( + session_id=session.id, + status=session.status, + story_id=story.id, + generation_job_id=None, + ) + + +async def abandon_voice_session_service( + session_id: str, + request: VoiceSessionAbandonRequest, + user_id: str, + db: AsyncSession, +) -> VoiceSessionSummaryResponse: + session = await _get_owned_session(db, session_id=session_id, user_id=user_id) + if session.status in FINAL_SESSION_STATUSES: + raise HTTPException(status_code=409, detail="Voice session is already closed.") + + session.status = "abandoned" + session.last_error = request.reason + await db.commit() + await db.refresh(session) + + await _record_session_event( + db, + session_id=session.id, + turn_id=None, + event_type="session_abandoned", + status="succeeded", + message="Voice session abandoned by the user.", + metadata={"reason": request.reason}, + ) + await db.refresh(session) + return _session_to_summary(session) diff --git a/backend/app/services/voice_session_storage.py 
b/backend/app/services/voice_session_storage.py new file mode 100644 index 0000000..5aa3e03 --- /dev/null +++ b/backend/app/services/voice_session_storage.py @@ -0,0 +1,48 @@ +"""Voice co-creation session storage helpers.""" + +from __future__ import annotations + +from pathlib import Path + +from app.core.config import settings + + +def session_storage_dir(session_id: str) -> Path: + """Return the storage directory for one voice session.""" + + return Path(settings.voice_session_storage_dir) / session_id + + +def build_turn_user_audio_path(session_id: str, turn_index: int, suffix: str) -> Path: + """Build the persisted path for one user-uploaded turn audio file.""" + + normalized_suffix = suffix.lstrip(".") or "webm" + return session_storage_dir(session_id) / f"turn-{turn_index:03d}-user.{normalized_suffix}" + + +def build_turn_assistant_audio_path(session_id: str, turn_index: int) -> Path: + """Build the persisted path for one generated assistant turn audio file.""" + + return session_storage_dir(session_id) / f"turn-{turn_index:03d}-assistant.mp3" + + +def write_session_audio(path: Path, audio_data: bytes) -> str: + """Persist session audio bytes atomically and return the saved path.""" + + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_suffix(f"{path.suffix}.tmp") + temp_path.write_bytes(audio_data) + temp_path.replace(path) + return str(path) + + +def read_session_audio(audio_path: str) -> bytes: + """Read persisted session audio bytes.""" + + return Path(audio_path).read_bytes() + + +def session_audio_exists(audio_path: str | None) -> bool: + """Whether one stored session audio file currently exists.""" + + return bool(audio_path) and Path(audio_path).is_file() diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 02d7ca9..ff216c4 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -187,6 +187,18 @@ def isolated_story_audio_cache(tmp_path, monkeypatch): yield +@pytest.fixture(autouse=True) +def 
isolated_voice_session_storage(tmp_path, monkeypatch): + """Use an isolated directory for voice session assets.""" + + monkeypatch.setattr( + settings, + "voice_session_storage_dir", + str(tmp_path / "voice_sessions"), + ) + yield + + @pytest.fixture def mock_text_provider(): """Mock text generation.""" diff --git a/backend/tests/test_voice_sessions.py b/backend/tests/test_voice_sessions.py new file mode 100644 index 0000000..30fe708 --- /dev/null +++ b/backend/tests/test_voice_sessions.py @@ -0,0 +1,201 @@ +from unittest.mock import AsyncMock, patch + +from httpx import ASGITransport, AsyncClient + +from app.db.database import get_db +from app.main import app +from app.services.adapters.text.models import StoryOutput + + +async def test_voice_session_create_and_fallback_turn_returns_audio( + db_session, + auth_token, +): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + with ( + patch( + "app.services.voice_session_service.generate_story_content", + new_callable=AsyncMock, + ) as mock_generate, + patch( + "app.services.voice_session_service.text_to_speech", + new_callable=AsyncMock, + ) as mock_tts, + ): + mock_generate.return_value = StoryOutput( + mode="generated", + title="小猫去太空", + story_text="小猫跳上纸飞机,朝着月亮轻轻挥手。", + cover_prompt_suggestion="温暖儿童绘本封面,小猫与月亮", + ) + mock_tts.return_value = b"fake-turn-audio" + + transport = ASGITransport(app=app) + try: + async with AsyncClient(transport=transport, base_url="http://test") as client: + client.cookies.set("access_token", auth_token) + + response = await client.post("/api/voice-sessions", json={}) + assert response.status_code == 201 + session_id = response.json()["id"] + + response = await client.post( + f"/api/voice-sessions/{session_id}/turns/fallback", + json={"transcript_text": "我想听一个小猫去太空的故事"}, + ) + assert response.status_code == 202 + turn_id = response.json()["turn_id"] + + response = await client.get( + 
f"/api/voice-sessions/{session_id}/turns/{turn_id}" + ) + assert response.status_code == 200 + turn_data = response.json() + assert turn_data["status"] == "audio_ready" + assert turn_data["detected_intent"] == "start_story" + assert turn_data["assistant_audio_ready"] is True + assert turn_data["assistant_audio_url"].endswith("/audio") + + response = await client.get(turn_data["assistant_audio_url"]) + assert response.status_code == 200 + assert response.content == b"fake-turn-audio" + assert response.headers["content-type"] == "audio/mpeg" + + response = await client.get(f"/api/voice-sessions/{session_id}") + assert response.status_code == 200 + session_data = response.json() + assert session_data["status"] == "waiting_user" + assert session_data["working_title"] == "小猫去太空" + assert session_data["can_continue"] is True + assert session_data["can_finalize"] is True + assert len(session_data["recent_turns"]) == 1 + finally: + app.dependency_overrides.clear() + + +async def test_voice_session_correct_turn_and_finalize_to_story( + db_session, + auth_token, +): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + with ( + patch( + "app.services.voice_session_service.generate_story_content", + new_callable=AsyncMock, + ) as mock_generate, + patch( + "app.services.voice_session_service.text_to_speech", + new_callable=AsyncMock, + ) as mock_tts, + ): + mock_generate.side_effect = [ + StoryOutput( + mode="generated", + title="小猫去太空", + story_text="第一段故事:小猫坐着纸飞机飞向月亮。", + cover_prompt_suggestion="温暖儿童绘本封面,小猫飞向月亮", + ), + StoryOutput( + mode="generated", + title="小猫去太空", + story_text="第二段故事:它在月亮上遇见了会发光的新朋友。", + cover_prompt_suggestion="温暖儿童绘本封面,小猫与月亮朋友", + ), + ] + mock_tts.side_effect = [b"turn-1-audio", b"turn-2-audio"] + + transport = ASGITransport(app=app) + try: + async with AsyncClient(transport=transport, base_url="http://test") as client: + client.cookies.set("access_token", auth_token) + + response = await 
client.post("/api/voice-sessions", json={}) + assert response.status_code == 201 + session_id = response.json()["id"] + + response = await client.post( + f"/api/voice-sessions/{session_id}/turns/fallback", + json={"transcript_text": "我想听一个小猫去太空的故事"}, + ) + assert response.status_code == 202 + + response = await client.post( + f"/api/voice-sessions/{session_id}/turns/fallback", + json={"transcript_text": "不要让它哭了,我想让它找到一个朋友"}, + ) + assert response.status_code == 202 + turn_id = response.json()["turn_id"] + + response = await client.get( + f"/api/voice-sessions/{session_id}/turns/{turn_id}" + ) + assert response.status_code == 200 + assert response.json()["detected_intent"] == "correct_story" + + response = await client.post( + f"/api/voice-sessions/{session_id}/finalize", + json={"save_story": True, "generate_cover": True}, + ) + assert response.status_code == 200 + finalize_data = response.json() + story_id = finalize_data["story_id"] + assert finalize_data["status"] == "completed" + + response = await client.get(f"/api/stories/{story_id}") + assert response.status_code == 200 + story_data = response.json() + assert story_data["title"] == "小猫去太空" + assert "第一段故事" in story_data["story_text"] + assert "第二段故事" in story_data["story_text"] + assert story_data["generation_status"] == "partial_ready" + + response = await client.get(f"/api/voice-sessions/{session_id}") + assert response.status_code == 200 + session_data = response.json() + assert session_data["status"] == "completed" + assert session_data["final_story_id"] == story_id + assert session_data["can_continue"] is False + finally: + app.dependency_overrides.clear() + + +async def test_voice_session_abandon_blocks_future_turns( + db_session, + auth_token, +): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + try: + async with AsyncClient(transport=transport, base_url="http://test") as client: + 
client.cookies.set("access_token", auth_token) + + response = await client.post("/api/voice-sessions", json={}) + assert response.status_code == 201 + session_id = response.json()["id"] + + response = await client.post( + f"/api/voice-sessions/{session_id}/abandon", + json={"reason": "孩子先去吃饭了"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "abandoned" + + response = await client.post( + f"/api/voice-sessions/{session_id}/turns/fallback", + json={"transcript_text": "我们继续讲吧"}, + ) + assert response.status_code == 409 + finally: + app.dependency_overrides.clear()