# dreamweaver/backend/tests/test_voice_sessions.py
#
# Integration tests for the voice-session HTTP API (/api/voice-sessions).

from unittest.mock import AsyncMock, patch
from httpx import ASGITransport, AsyncClient
from app.core.config import settings
from app.db.database import get_db
from app.main import app
from app.services.adapters.text.models import StoryOutput
async def test_voice_session_create_and_fallback_turn_returns_audio(
    db_session,
    auth_token,
):
    """Create a session, submit a text fallback turn, and fetch its audio.

    Story generation and TTS are mocked, so the turn resolves to
    ``audio_ready``. The test then checks the turn detail, the served
    assistant audio bytes, and the session summary after one turn.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup. If a patch target fails to resolve,
    # the get_db override is still removed instead of leaking into later
    # tests that share the module-level ``app``.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            mock_generate.return_value = StoryOutput(
                mode="generated",
                title="小猫去太空",
                story_text="小猫跳上纸飞机,朝着月亮轻轻挥手。",
                cover_prompt_suggestion="温暖儿童绘本封面,小猫与月亮",
            )
            mock_tts.return_value = b"fake-turn-audio"
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                session_id = response.json()["id"]

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/fallback",
                    json={"transcript_text": "我想听一个小猫去太空的故事"},
                )
                assert response.status_code == 202
                turn_id = response.json()["turn_id"]

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{turn_id}"
                )
                assert response.status_code == 200
                turn_data = response.json()
                assert turn_data["status"] == "audio_ready"
                assert turn_data["detected_intent"] == "start_story"
                assert turn_data["assistant_audio_ready"] is True
                assert turn_data["assistant_audio_url"].endswith("/audio")

                # The served audio must be exactly what the mocked TTS returned.
                response = await client.get(turn_data["assistant_audio_url"])
                assert response.status_code == 200
                assert response.content == b"fake-turn-audio"
                assert response.headers["content-type"] == "audio/mpeg"

                response = await client.get(f"/api/voice-sessions/{session_id}")
                assert response.status_code == 200
                session_data = response.json()
                assert session_data["status"] == "waiting_user"
                assert session_data["working_title"] == "小猫去太空"
                assert session_data["can_continue"] is True
                assert session_data["can_finalize"] is True
                assert len(session_data["recent_turns"]) == 1
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_correct_turn_and_finalize_to_story(
    db_session,
    auth_token,
):
    """Apply a correction turn, then finalize the session into a saved story.

    The first fallback turn starts the story. The second is detected as a
    ``correct_story`` intent. Finalizing merges both generated segments into
    one story and marks the session completed.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            # One generated segment and one audio clip per turn, in order.
            mock_generate.side_effect = [
                StoryOutput(
                    mode="generated",
                    title="小猫去太空",
                    story_text="第一段故事:小猫坐着纸飞机飞向月亮。",
                    cover_prompt_suggestion="温暖儿童绘本封面,小猫飞向月亮",
                ),
                StoryOutput(
                    mode="generated",
                    title="小猫去太空",
                    story_text="第二段故事:它在月亮上遇见了会发光的新朋友。",
                    cover_prompt_suggestion="温暖儿童绘本封面,小猫与月亮朋友",
                ),
            ]
            mock_tts.side_effect = [b"turn-1-audio", b"turn-2-audio"]
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                session_id = response.json()["id"]

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/fallback",
                    json={"transcript_text": "我想听一个小猫去太空的故事"},
                )
                assert response.status_code == 202

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/fallback",
                    json={"transcript_text": "不要让它哭了,我想让它找到一个朋友"},
                )
                assert response.status_code == 202
                turn_id = response.json()["turn_id"]

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{turn_id}"
                )
                assert response.status_code == 200
                assert response.json()["detected_intent"] == "correct_story"

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/finalize",
                    json={"save_story": True, "generate_cover": True},
                )
                assert response.status_code == 200
                finalize_data = response.json()
                story_id = finalize_data["story_id"]
                assert finalize_data["status"] == "completed"

                # The saved story must contain both turns' segments.
                response = await client.get(f"/api/stories/{story_id}")
                assert response.status_code == 200
                story_data = response.json()
                assert story_data["title"] == "小猫去太空"
                assert "第一段故事" in story_data["story_text"]
                assert "第二段故事" in story_data["story_text"]
                assert story_data["generation_status"] == "partial_ready"

                response = await client.get(f"/api/voice-sessions/{session_id}")
                assert response.status_code == 200
                session_data = response.json()
                assert session_data["status"] == "completed"
                assert session_data["final_story_id"] == story_id
                assert session_data["can_continue"] is False
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_abandon_blocks_future_turns(
    db_session,
    auth_token,
):
    """After a session is abandoned, submitting another fallback turn yields 409."""

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            created = await client.post("/api/voice-sessions", json={})
            assert created.status_code == 201
            session_id = created.json()["id"]
            session_url = f"/api/voice-sessions/{session_id}"

            abandoned = await client.post(
                f"{session_url}/abandon",
                json={"reason": "孩子先去吃饭了"},
            )
            assert abandoned.status_code == 200
            assert abandoned.json()["status"] == "abandoned"

            # Any new turn on the abandoned session must be refused.
            rejected = await client.post(
                f"{session_url}/turns/fallback",
                json={"transcript_text": "我们继续讲吧"},
            )
            assert rejected.status_code == 409
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_uploaded_audio_turn_uses_demo_transcript_hint(
    db_session,
    auth_token,
):
    """Upload an audio turn whose transcript comes from ``transcript_hint``.

    The turn should report the ``demo`` transcription provider, reach
    ``audio_ready``, and serve the uploaded user audio back unchanged with its
    original content type.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            mock_generate.return_value = StoryOutput(
                mode="generated",
                title="小鲸鱼找朋友",
                story_text="小鲸鱼在海面上遇见了一只会唱歌的海鸥。",
                cover_prompt_suggestion="温暖儿童绘本封面,小鲸鱼和海鸥",
            )
            mock_tts.return_value = b"fake-upload-audio"
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                session_id = response.json()["id"]

                # Multipart upload: raw audio bytes plus form fields.
                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns",
                    files={
                        "audio_file": ("turn.webm", b"fake-webm-audio", "audio/webm"),
                    },
                    data={
                        "duration_ms": "3200",
                        "transcript_hint": "我想听一个小鲸鱼找朋友的故事",
                    },
                )
                assert response.status_code == 202
                turn_data = response.json()
                assert turn_data["status"] == "audio_ready"
                assert turn_data["transcription_provider"] == "demo"
                turn_id = turn_data["turn_id"]

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{turn_id}"
                )
                assert response.status_code == 200
                detail = response.json()
                assert detail["user_audio_ready"] is True
                assert detail["user_audio_url"].endswith("/user-audio")
                assert detail["transcription_provider"] == "demo"
                assert detail["assistant_audio_ready"] is True

                # The stored user audio is served back byte-for-byte.
                response = await client.get(detail["user_audio_url"])
                assert response.status_code == 200
                assert response.content == b"fake-webm-audio"
                assert response.headers["content-type"] == "audio/webm"
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_list_orders_recent_sessions_first(
    db_session,
    auth_token,
):
    """The session list returns the most recently created session first.

    Two sessions each get one fallback turn. ``GET /api/voice-sessions``
    must list the second before the first, with per-session turn summaries.
    ``active_only=true`` must include both sessions.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            mock_generate.side_effect = [
                StoryOutput(
                    mode="generated",
                    title="第一场冒险",
                    story_text="第一段故事。",
                    cover_prompt_suggestion="封面一",
                ),
                StoryOutput(
                    mode="generated",
                    title="第二场冒险",
                    story_text="第二段故事。",
                    cover_prompt_suggestion="封面二",
                ),
            ]
            mock_tts.side_effect = [b"audio-1", b"audio-2"]
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                # Check each setup call's status, so a setup failure reports a
                # clear status mismatch rather than a KeyError or a misleading
                # ordering failure below.
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                first_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{first_session_id}/turns/fallback",
                    json={"transcript_text": "第一个故事"},
                )
                assert response.status_code == 202

                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                second_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{second_session_id}/turns/fallback",
                    json={"transcript_text": "第二个故事"},
                )
                assert response.status_code == 202

                response = await client.get("/api/voice-sessions?limit=8")
                assert response.status_code == 200
                sessions = response.json()
                # Other sessions may already exist for this user, so only the
                # relative order of the two new ones is pinned.
                assert len(sessions) >= 2
                assert sessions[0]["id"] == second_session_id
                assert sessions[1]["id"] == first_session_id
                assert sessions[0]["total_turns"] == 1
                assert sessions[0]["last_turn_status"] == "audio_ready"

                response = await client.get("/api/voice-sessions?active_only=true")
                assert response.status_code == 200
                active_sessions = response.json()
                assert {item["id"] for item in active_sessions} >= {
                    first_session_id,
                    second_session_id,
                }
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_active_endpoint_returns_latest_active_session(
    db_session,
    auth_token,
):
    """``GET /api/voice-sessions/active`` skips abandoned sessions.

    An older session is abandoned and a newer one receives a turn. The
    active endpoint must return the newer session in ``waiting_user``.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            mock_generate.return_value = StoryOutput(
                mode="generated",
                title="活动会话",
                story_text="一段活动中的故事。",
                cover_prompt_suggestion="活动会话封面",
            )
            mock_tts.return_value = b"active-audio"
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)

                # Check every setup step, so the final assertions cannot pass
                # or fail for the wrong reason.
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                old_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{old_session_id}/abandon",
                    json={"reason": "旧会话结束"},
                )
                assert response.status_code == 200

                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                active_session_id = response.json()["id"]
                response = await client.post(
                    f"/api/voice-sessions/{active_session_id}/turns/fallback",
                    json={"transcript_text": "请继续一个新故事"},
                )
                assert response.status_code == 202

                response = await client.get("/api/voice-sessions/active")
                assert response.status_code == 200
                data = response.json()
                assert data["id"] == active_session_id
                assert data["can_continue"] is True
                assert data["status"] == "waiting_user"
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_can_retry_failed_turn_from_saved_transcript(
    db_session,
    auth_token,
):
    """A failed turn can be retried and produces a new, successful turn.

    The first generation call raises, so the turn ends up ``failed``. Calling
    ``/retry`` must create a different turn that reaches ``audio_ready``
    using the second generated story.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            # First call fails (simulated provider outage); the retry succeeds.
            mock_generate.side_effect = [
                RuntimeError("provider down"),
                StoryOutput(
                    mode="generated",
                    title="重试成功",
                    story_text="重试后的故事终于顺利继续了。",
                    cover_prompt_suggestion="重试封面",
                ),
            ]
            mock_tts.return_value = b"retry-turn-audio"
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                session_id = response.json()["id"]

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/fallback",
                    json={"transcript_text": "先讲一个会失败的回合"},
                )
                assert response.status_code == 202
                failed_turn_id = response.json()["turn_id"]

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}"
                )
                assert response.status_code == 200
                assert response.json()["status"] == "failed"

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}/retry"
                )
                assert response.status_code == 202
                retried_turn_id = response.json()["turn_id"]
                assert retried_turn_id != failed_turn_id

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{retried_turn_id}"
                )
                assert response.status_code == 200
                retried_turn = response.json()
                assert retried_turn["status"] == "audio_ready"
                assert retried_turn["assistant_text"] == "重试后的故事终于顺利继续了。"
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_can_retry_missing_assistant_audio(
    db_session,
    auth_token,
):
    """Assistant audio can be regenerated after the first TTS attempt fails.

    The first TTS call raises, so the turn stays ``narrative_ready`` without
    audio. ``/retry-audio`` must then bring it to ``audio_ready``.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # The try also wraps patch setup, so the get_db override cannot leak into
    # later tests if a patch target fails to resolve.
    try:
        with (
            patch(
                "app.services.voice_session_service.generate_story_content",
                new_callable=AsyncMock,
            ) as mock_generate,
            patch(
                "app.services.voice_session_service.text_to_speech",
                new_callable=AsyncMock,
            ) as mock_tts,
        ):
            mock_generate.return_value = StoryOutput(
                mode="generated",
                title="补发语音",
                story_text="这一轮先有文本,稍后再补发语音。",
                cover_prompt_suggestion="补发语音封面",
            )
            # First TTS call fails (simulated outage); the audio retry succeeds.
            mock_tts.side_effect = [RuntimeError("tts down"), b"recovered-audio"]
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                client.cookies.set("access_token", auth_token)
                response = await client.post("/api/voice-sessions", json={})
                assert response.status_code == 201
                session_id = response.json()["id"]

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/fallback",
                    json={"transcript_text": "请先给我一段只有文本的结果"},
                )
                assert response.status_code == 202
                turn_id = response.json()["turn_id"]

                response = await client.get(
                    f"/api/voice-sessions/{session_id}/turns/{turn_id}"
                )
                assert response.status_code == 200
                turn = response.json()
                assert turn["status"] == "narrative_ready"
                assert turn["assistant_audio_ready"] is False

                response = await client.post(
                    f"/api/voice-sessions/{session_id}/turns/{turn_id}/retry-audio"
                )
                assert response.status_code == 200
                retried = response.json()
                assert retried["status"] == "audio_ready"
                assert retried["assistant_audio_ready"] is True
    finally:
        app.dependency_overrides.clear()
async def test_voice_session_uploaded_audio_respects_size_limit(
    db_session,
    auth_token,
    monkeypatch,
):
    """An upload larger than ``settings.voice_turn_max_upload_bytes`` gets 413."""

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    # Lower the limit to 4 bytes so the 5-byte payload below exceeds it.
    monkeypatch.setattr(settings, "voice_turn_max_upload_bytes", 4)
    oversized_audio = b"12345"
    transport = ASGITransport(app=app)
    try:
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            client.cookies.set("access_token", auth_token)

            created = await client.post("/api/voice-sessions", json={})
            assert created.status_code == 201
            session_id = created.json()["id"]

            upload = await client.post(
                f"/api/voice-sessions/{session_id}/turns",
                files={"audio_file": ("turn.webm", oversized_audio, "audio/webm")},
                data={"transcript_hint": "太长了"},
            )
            assert upload.status_code == 413
    finally:
        app.dependency_overrides.clear()