feat: persist story generation states and cache audio
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""测试配置和 fixtures。"""
|
||||
"""Pytest fixtures for backend tests."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Base, Story, User
|
||||
@@ -19,7 +20,8 @@ from app.main import app
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""创建内存数据库引擎。"""
|
||||
"""Create an in-memory database engine."""
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
@@ -29,7 +31,8 @@ async def async_engine():
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""创建数据库会话。"""
|
||||
"""Create a database session."""
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
@@ -39,7 +42,8 @@ async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
"""创建测试用户。"""
|
||||
"""Create a test user."""
|
||||
|
||||
user = User(
|
||||
id="github:12345",
|
||||
name="Test User",
|
||||
@@ -54,13 +58,74 @@ async def test_user(db_session: AsyncSession) -> User:
|
||||
|
||||
@pytest.fixture
|
||||
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""创建测试故事。"""
|
||||
"""Create a plain generated story."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="测试故事",
|
||||
story_text="从前有一只小兔子...",
|
||||
story_text="从前有一只小兔子。",
|
||||
cover_prompt="A cute rabbit in a forest",
|
||||
mode="generated",
|
||||
generation_status="narrative_ready",
|
||||
image_status="not_requested",
|
||||
audio_status="not_requested",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def storybook_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""Create a storybook-mode story."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="森林绘本冒险",
|
||||
story_text=None,
|
||||
pages=[
|
||||
{
|
||||
"page_number": 1,
|
||||
"text": "小兔子走进了会发光的森林。",
|
||||
"image_prompt": "A glowing forest with a curious rabbit",
|
||||
"image_url": "https://example.com/page-1.png",
|
||||
},
|
||||
{
|
||||
"page_number": 2,
|
||||
"text": "它遇见了一位会唱歌的萤火虫朋友。",
|
||||
"image_prompt": "A rabbit meeting a singing firefly",
|
||||
"image_url": None,
|
||||
},
|
||||
],
|
||||
cover_prompt="A magical forest storybook cover",
|
||||
image_url="https://example.com/storybook-cover.png",
|
||||
mode="storybook",
|
||||
generation_status="degraded_completed",
|
||||
image_status="failed",
|
||||
audio_status="not_requested",
|
||||
last_error="第 2 页插图生成失败",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def degraded_story_with_text(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""Create a readable story whose image generation already failed."""
|
||||
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="部分完成的测试故事",
|
||||
story_text="从前有一只小兔子继续冒险。",
|
||||
cover_prompt="A rabbit under the moon",
|
||||
mode="generated",
|
||||
generation_status="degraded_completed",
|
||||
image_status="failed",
|
||||
audio_status="not_requested",
|
||||
last_error="封面生成失败",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
@@ -70,13 +135,14 @@ async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token(test_user: User) -> str:
|
||||
"""生成测试用户的 JWT token。"""
|
||||
"""Create a JWT token for the test user."""
|
||||
|
||||
return create_access_token({"sub": test_user.id})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(db_session: AsyncSession) -> TestClient:
|
||||
"""创建测试客户端。"""
|
||||
"""Create a test client."""
|
||||
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
@@ -89,35 +155,45 @@ def client(db_session: AsyncSession) -> TestClient:
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client: TestClient, auth_token: str) -> TestClient:
|
||||
"""带认证的测试客户端。"""
|
||||
"""Create an authenticated test client."""
|
||||
|
||||
client.cookies.set("access_token", auth_token)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_rate_limit():
|
||||
"""默认绕过限流,让非限流测试正常运行。"""
|
||||
"""Bypass rate limiting in most tests."""
|
||||
|
||||
with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis:
|
||||
# 创建一个模拟的 Redis 客户端,所有操作返回安全默认值
|
||||
redis_instance = AsyncMock()
|
||||
redis_instance.incr.return_value = 1 # 始终返回 1 (不触发限流)
|
||||
redis_instance.incr.return_value = 1
|
||||
redis_instance.expire.return_value = True
|
||||
redis_instance.get.return_value = None # 无锁定记录
|
||||
redis_instance.get.return_value = None
|
||||
redis_instance.ttl.return_value = 0
|
||||
redis_instance.delete.return_value = 1
|
||||
mock_redis.return_value = redis_instance
|
||||
yield redis_instance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_story_audio_cache(tmp_path, monkeypatch):
|
||||
"""Use an isolated directory for cached story audio files."""
|
||||
|
||||
monkeypatch.setattr(settings, "story_audio_cache_dir", str(tmp_path / "audio"))
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_provider():
|
||||
"""Mock 文本生成适配器 API 调用。"""
|
||||
"""Mock text generation."""
|
||||
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="小兔子的冒险",
|
||||
story_text="从前有一只小兔子...",
|
||||
story_text="从前有一只小兔子。",
|
||||
cover_prompt_suggestion="A cute rabbit",
|
||||
)
|
||||
|
||||
@@ -128,7 +204,8 @@ def mock_text_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_provider():
|
||||
"""Mock 图像生成。"""
|
||||
"""Mock image generation."""
|
||||
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = "https://example.com/image.png"
|
||||
yield mock
|
||||
@@ -136,7 +213,8 @@ def mock_image_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_provider():
|
||||
"""Mock TTS。"""
|
||||
"""Mock text-to-speech generation."""
|
||||
|
||||
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = b"fake-audio-bytes"
|
||||
yield mock
|
||||
@@ -144,7 +222,8 @@ def mock_tts_provider():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
|
||||
"""Mock 所有 AI 供应商。"""
|
||||
"""Group all mocked providers."""
|
||||
|
||||
return {
|
||||
"text_primary": mock_text_provider,
|
||||
"image_primary": mock_image_provider,
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
"""故事 API 测试。"""
|
||||
"""Tests for story-related API endpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.adapters.storybook.primary import Storybook, StorybookPage
|
||||
|
||||
# ── 注意 ──────────────────────────────────────────────────────────────────────
|
||||
# 以下路由尚未实现 (stories.py 中没有对应端点),相关测试标记为 skip:
|
||||
# GET /api/stories (列表)
|
||||
# GET /api/stories/{id} (详情)
|
||||
# DELETE /api/stories/{id} (删除)
|
||||
# POST /api/image/generate/{id} (封面图片生成)
|
||||
# GET /api/audio/{id} (音频)
|
||||
# 实现后请取消 skip 标记。
|
||||
|
||||
def build_storybook_output() -> Storybook:
|
||||
"""Create a reusable mocked storybook payload."""
|
||||
|
||||
return Storybook(
|
||||
title="森林里的发光冒险",
|
||||
main_character="小兔子露露",
|
||||
art_style="温暖水彩",
|
||||
cover_prompt="A glowing forest storybook cover",
|
||||
pages=[
|
||||
StorybookPage(
|
||||
page_number=1,
|
||||
text="露露第一次走进会发光的森林。",
|
||||
image_prompt="Lulu entering a glowing forest",
|
||||
),
|
||||
StorybookPage(
|
||||
page_number=2,
|
||||
text="她遇到了一只会唱歌的萤火虫。",
|
||||
image_prompt="Lulu meeting a singing firefly",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestStoryGenerate:
|
||||
"""故事生成测试。"""
|
||||
"""Tests for basic story generation."""
|
||||
|
||||
def test_generate_without_auth(self, client: TestClient):
|
||||
"""未登录时生成故事。"""
|
||||
response = client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
@@ -28,7 +43,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_with_empty_data(self, auth_client: TestClient):
|
||||
"""空数据生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": ""},
|
||||
@@ -36,7 +50,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_with_invalid_type(self, auth_client: TestClient):
|
||||
"""无效类型生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "invalid", "data": "test"},
|
||||
@@ -44,7 +57,6 @@ class TestStoryGenerate:
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
|
||||
"""成功生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
@@ -55,82 +67,96 @@ class TestStoryGenerate:
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["generation_status"] == "narrative_ready"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
|
||||
class TestStoryList:
|
||||
"""故事列表测试。"""
|
||||
"""Tests for story listing."""
|
||||
|
||||
def test_list_without_auth(self, client: TestClient):
|
||||
"""未登录时获取列表。"""
|
||||
response = client.get("/api/stories")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_list_empty(self, auth_client: TestClient):
|
||||
"""空列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_list_with_stories(self, auth_client: TestClient, test_story):
|
||||
"""有故事时获取列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == test_story.id
|
||||
assert data[0]["title"] == test_story.title
|
||||
assert data[0]["generation_status"] == "narrative_ready"
|
||||
assert data[0]["image_status"] == "not_requested"
|
||||
assert data[0]["audio_status"] == "not_requested"
|
||||
|
||||
def test_list_pagination(self, auth_client: TestClient, test_story):
|
||||
"""分页测试。"""
|
||||
response = auth_client.get("/api/stories?limit=1&offset=0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert len(response.json()) == 1
|
||||
|
||||
response = auth_client.get("/api/stories?limit=1&offset=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
assert len(response.json()) == 0
|
||||
|
||||
|
||||
class TestStoryDetail:
|
||||
"""故事详情测试。"""
|
||||
"""Tests for story detail retrieval."""
|
||||
|
||||
def test_get_story_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取详情。"""
|
||||
response = client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_story_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_story_success(self, auth_client: TestClient, test_story):
|
||||
"""成功获取详情。"""
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == test_story.id
|
||||
assert data["title"] == test_story.title
|
||||
assert data["story_text"] == test_story.story_text
|
||||
assert data["generation_status"] == "narrative_ready"
|
||||
assert data["image_status"] == "not_requested"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
def test_get_storybook_success(self, auth_client: TestClient, storybook_story):
|
||||
response = auth_client.get(f"/api/stories/{storybook_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == storybook_story.id
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["story_text"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["pages"][0]["page_number"] == 1
|
||||
assert data["image_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "第 2 页" in data["last_error"]
|
||||
|
||||
|
||||
class TestStoryDelete:
|
||||
"""故事删除测试。"""
|
||||
"""Tests for story deletion."""
|
||||
|
||||
def test_delete_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时删除。"""
|
||||
response = client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_delete_not_found(self, auth_client: TestClient):
|
||||
"""删除不存在的故事。"""
|
||||
response = auth_client.delete("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_success(self, auth_client: TestClient, test_story):
|
||||
"""成功删除故事。"""
|
||||
response = auth_client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
@@ -140,11 +166,14 @@ class TestStoryDelete:
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Rate limit 测试。"""
|
||||
"""Tests for story generation rate limiting."""
|
||||
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, mock_text_provider, bypass_rate_limit):
|
||||
"""正常请求不触发限流。"""
|
||||
# bypass_rate_limit 默认 incr 返回 1,不触发限流
|
||||
def test_rate_limit_allows_normal_requests(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
bypass_rate_limit,
|
||||
):
|
||||
for _ in range(3):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate",
|
||||
@@ -152,9 +181,11 @@ class TestRateLimit:
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, bypass_rate_limit):
|
||||
"""超限请求被阻止。"""
|
||||
# 让 incr 返回超限值 (> RATE_LIMIT_REQUESTS)
|
||||
def test_rate_limit_blocks_excess_requests(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
bypass_rate_limit,
|
||||
):
|
||||
bypass_rate_limit.incr.return_value = 11
|
||||
|
||||
response = auth_client.post(
|
||||
@@ -166,52 +197,118 @@ class TestRateLimit:
|
||||
|
||||
|
||||
class TestImageGenerate:
|
||||
"""封面图片生成测试。"""
|
||||
"""Tests for cover generation endpoint."""
|
||||
|
||||
def test_generate_image_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时生成图片。"""
|
||||
response = client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_image_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.post("/api/image/generate/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestAudio:
|
||||
"""语音朗读测试。"""
|
||||
"""Tests for story audio endpoint."""
|
||||
|
||||
def test_get_audio_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取音频。"""
|
||||
response = client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_audio_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/audio/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
|
||||
"""成功获取音频。"""
|
||||
def test_get_audio_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_tts_provider,
|
||||
):
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.content == b"fake-audio-bytes"
|
||||
|
||||
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
|
||||
assert cached_audio_path.is_file()
|
||||
|
||||
second_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert second_response.status_code == 200
|
||||
assert second_response.content == b"fake-audio-bytes"
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "ready"
|
||||
assert detail["generation_status"] == "completed"
|
||||
assert detail["last_error"] is None
|
||||
|
||||
def test_get_audio_regenerates_when_cache_file_is_missing(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_tts_provider,
|
||||
):
|
||||
first_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert first_response.status_code == 200
|
||||
|
||||
cached_audio_path = Path(settings.story_audio_cache_dir) / f"story-{test_story.id}.mp3"
|
||||
cached_audio_path.unlink()
|
||||
mock_tts_provider.reset_mock()
|
||||
|
||||
second_response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert second_response.status_code == 200
|
||||
assert second_response.content == b"fake-audio-bytes"
|
||||
assert cached_audio_path.is_file()
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
def test_get_audio_failure_updates_status(self, auth_client: TestClient, test_story):
|
||||
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock_tts:
|
||||
mock_tts.side_effect = Exception("TTS provider timeout")
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 500
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "failed"
|
||||
assert detail["generation_status"] == "degraded_completed"
|
||||
assert "TTS provider timeout" in detail["last_error"]
|
||||
|
||||
def test_get_audio_success_preserves_existing_image_error(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
degraded_story_with_text,
|
||||
mock_tts_provider,
|
||||
):
|
||||
response = auth_client.get(f"/api/audio/{degraded_story_with_text.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"fake-audio-bytes"
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
detail_response = auth_client.get(f"/api/stories/{degraded_story_with_text.id}")
|
||||
detail = detail_response.json()
|
||||
assert detail["audio_status"] == "ready"
|
||||
assert detail["generation_status"] == "degraded_completed"
|
||||
assert detail["last_error"] == "封面生成失败"
|
||||
|
||||
|
||||
class TestGenerateFull:
|
||||
"""完整故事生成测试(/api/stories/generate/full)。"""
|
||||
"""Tests for complete story generation."""
|
||||
|
||||
def test_generate_full_without_auth(self, client: TestClient):
|
||||
"""未登录时生成完整故事。"""
|
||||
response = client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""成功生成完整故事(含图片)。"""
|
||||
def test_generate_full_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
@@ -223,11 +320,14 @@ class TestGenerateFull:
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["audio_ready"] is False # 音频按需生成
|
||||
assert data["audio_ready"] is False
|
||||
assert data["errors"] == {}
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
||||
"""图片生成失败时返回部分成功。"""
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
response = auth_client.post(
|
||||
@@ -239,9 +339,17 @@ class TestGenerateFull:
|
||||
assert data["image_url"] is None
|
||||
assert "image" in data["errors"]
|
||||
assert "Image API error" in data["errors"]["image"]
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "Image API error" in data["last_error"]
|
||||
|
||||
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""带教育主题生成故事。"""
|
||||
def test_generate_full_with_education_theme(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/stories/generate/full",
|
||||
json={
|
||||
@@ -257,11 +365,80 @@ class TestGenerateFull:
|
||||
|
||||
|
||||
class TestImageGenerateSuccess:
|
||||
"""封面图片生成成功测试。"""
|
||||
"""Tests for successful cover generation."""
|
||||
|
||||
def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
|
||||
"""成功生成图片。"""
|
||||
def test_generate_image_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
test_story,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
|
||||
|
||||
class TestStorybookGenerate:
|
||||
"""Tests for storybook generation status handling."""
|
||||
|
||||
def test_generate_storybook_success(self, auth_client: TestClient):
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = [
|
||||
"https://example.com/storybook-cover.png",
|
||||
"https://example.com/storybook-page-1.png",
|
||||
"https://example.com/storybook-page-2.png",
|
||||
]
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/storybook/generate",
|
||||
json={
|
||||
"keywords": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["cover_url"] == "https://example.com/storybook-cover.png"
|
||||
|
||||
def test_generate_storybook_partial_image_failure(self, auth_client: TestClient):
|
||||
async def image_side_effect(prompt: str, **kwargs):
|
||||
if "singing firefly" in prompt:
|
||||
raise Exception("Image API error")
|
||||
slug = prompt.split()[0].lower()
|
||||
return f"https://example.com/{slug}.png"
|
||||
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = image_side_effect
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/storybook/generate",
|
||||
json={
|
||||
"keywords": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generation_status"] == "degraded_completed"
|
||||
assert data["image_status"] == "failed"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert "第 2 页插图生成失败" in data["last_error"]
|
||||
|
||||
Reference in New Issue
Block a user