feat: improve voice studio alpha recovery flow
This commit is contained in:
@@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.database import get_db
|
||||
from app.main import app
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
@@ -343,3 +344,221 @@ async def test_voice_session_list_orders_recent_sessions_first(
|
||||
}
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_active_endpoint_returns_latest_active_session(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
):
|
||||
mock_generate.return_value = StoryOutput(
|
||||
mode="generated",
|
||||
title="活动会话",
|
||||
story_text="一段活动中的故事。",
|
||||
cover_prompt_suggestion="活动会话封面",
|
||||
)
|
||||
mock_tts.return_value = b"active-audio"
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
old_session_id = response.json()["id"]
|
||||
await client.post(
|
||||
f"/api/voice-sessions/{old_session_id}/abandon",
|
||||
json={"reason": "旧会话结束"},
|
||||
)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
active_session_id = response.json()["id"]
|
||||
await client.post(
|
||||
f"/api/voice-sessions/{active_session_id}/turns/fallback",
|
||||
json={"transcript_text": "请继续一个新故事"},
|
||||
)
|
||||
|
||||
response = await client.get("/api/voice-sessions/active")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == active_session_id
|
||||
assert data["can_continue"] is True
|
||||
assert data["status"] == "waiting_user"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_can_retry_failed_turn_from_saved_transcript(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
):
|
||||
mock_generate.side_effect = [
|
||||
RuntimeError("provider down"),
|
||||
StoryOutput(
|
||||
mode="generated",
|
||||
title="重试成功",
|
||||
story_text="重试后的故事终于顺利继续了。",
|
||||
cover_prompt_suggestion="重试封面",
|
||||
),
|
||||
]
|
||||
mock_tts.return_value = b"retry-turn-audio"
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "先讲一个会失败的回合"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
failed_turn_id = response.json()["turn_id"]
|
||||
|
||||
response = await client.get(
|
||||
f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "failed"
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/{failed_turn_id}/retry"
|
||||
)
|
||||
assert response.status_code == 202
|
||||
retried_turn_id = response.json()["turn_id"]
|
||||
assert retried_turn_id != failed_turn_id
|
||||
|
||||
response = await client.get(
|
||||
f"/api/voice-sessions/{session_id}/turns/{retried_turn_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
retried_turn = response.json()
|
||||
assert retried_turn["status"] == "audio_ready"
|
||||
assert retried_turn["assistant_text"] == "重试后的故事终于顺利继续了。"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_can_retry_missing_assistant_audio(
|
||||
db_session,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.voice_session_service.generate_story_content",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_generate,
|
||||
patch(
|
||||
"app.services.voice_session_service.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts,
|
||||
):
|
||||
mock_generate.return_value = StoryOutput(
|
||||
mode="generated",
|
||||
title="补发语音",
|
||||
story_text="这一轮先有文本,稍后再补发语音。",
|
||||
cover_prompt_suggestion="补发语音封面",
|
||||
)
|
||||
mock_tts.side_effect = [RuntimeError("tts down"), b"recovered-audio"]
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/fallback",
|
||||
json={"transcript_text": "请先给我一段只有文本的结果"},
|
||||
)
|
||||
assert response.status_code == 202
|
||||
turn_id = response.json()["turn_id"]
|
||||
|
||||
response = await client.get(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
turn = response.json()
|
||||
assert turn["status"] == "narrative_ready"
|
||||
assert turn["assistant_audio_ready"] is False
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns/{turn_id}/retry-audio"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
retried = response.json()
|
||||
assert retried["status"] == "audio_ready"
|
||||
assert retried["assistant_audio_ready"] is True
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_voice_session_uploaded_audio_respects_size_limit(
|
||||
db_session,
|
||||
auth_token,
|
||||
monkeypatch,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
monkeypatch.setattr(settings, "voice_turn_max_upload_bytes", 4)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post("/api/voice-sessions", json={})
|
||||
assert response.status_code == 201
|
||||
session_id = response.json()["id"]
|
||||
|
||||
response = await client.post(
|
||||
f"/api/voice-sessions/{session_id}/turns",
|
||||
files={
|
||||
"audio_file": ("turn.webm", b"12345", "audio/webm"),
|
||||
},
|
||||
data={"transcript_hint": "太长了"},
|
||||
)
|
||||
assert response.status_code == 413
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
Reference in New Issue
Block a user