Initial commit: clean project structure
- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+) - Frontend: Vue 3 + TypeScript + Pinia + Tailwind - Admin Frontend: separate Vue 3 app for management - Docker Compose: 9 services orchestration - Specs: design prototypes, memory system PRD, product roadmap Cleanup performed: - Removed temporary debug scripts from backend root - Removed deprecated admin_app.py (embedded UI) - Removed duplicate docs from admin-frontend - Updated .gitignore for Vite cache and egg-info
This commit is contained in:
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
146
backend/tests/conftest.py
Normal file
146
backend/tests/conftest.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""测试配置和 fixtures。"""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing")
|
||||
os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:")
|
||||
|
||||
from app.core.security import create_access_token
|
||||
from app.api.stories import _request_log
|
||||
from app.db.database import get_db
|
||||
from app.db.models import Base, Story, User
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""创建内存数据库引擎。"""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""创建数据库会话。"""
|
||||
session_factory = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
"""创建测试用户。"""
|
||||
user = User(
|
||||
id="github:12345",
|
||||
name="Test User",
|
||||
avatar_url="https://example.com/avatar.png",
|
||||
provider="github",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
|
||||
"""创建测试故事。"""
|
||||
story = Story(
|
||||
user_id=test_user.id,
|
||||
title="测试故事",
|
||||
story_text="从前有一只小兔子...",
|
||||
cover_prompt="A cute rabbit in a forest",
|
||||
mode="generated",
|
||||
)
|
||||
db_session.add(story)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(story)
|
||||
return story
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token(test_user: User) -> str:
|
||||
"""生成测试用户的 JWT token。"""
|
||||
return create_access_token({"sub": test_user.id})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(db_session: AsyncSession) -> TestClient:
|
||||
"""创建测试客户端。"""
|
||||
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_client(client: TestClient, auth_token: str) -> TestClient:
|
||||
"""带认证的测试客户端。"""
|
||||
client.cookies.set("access_token", auth_token)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_rate_limit_cache():
|
||||
"""确保每个测试用例的限流缓存互不影响。"""
|
||||
_request_log.clear()
|
||||
yield
|
||||
_request_log.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_provider():
|
||||
"""Mock 文本生成适配器 API 调用。"""
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="小兔子的冒险",
|
||||
story_text="从前有一只小兔子...",
|
||||
cover_prompt_suggestion="A cute rabbit",
|
||||
)
|
||||
|
||||
with patch("app.api.stories.generate_story_content", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = mock_result
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image_provider():
|
||||
"""Mock 图像生成。"""
|
||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = "https://example.com/image.png"
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_provider():
|
||||
"""Mock TTS。"""
|
||||
with patch("app.api.stories.text_to_speech", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = b"fake-audio-bytes"
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
|
||||
"""Mock 所有 AI 供应商。"""
|
||||
return {
|
||||
"text_primary": mock_text_provider,
|
||||
"image_primary": mock_image_provider,
|
||||
"tts_primary": mock_tts_provider,
|
||||
}
|
||||
65
backend/tests/test_auth.py
Normal file
65
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""认证相关测试。"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.security import create_access_token, decode_access_token
|
||||
|
||||
|
||||
class TestJWT:
|
||||
"""JWT token 测试。"""
|
||||
|
||||
def test_create_and_decode_token(self):
|
||||
"""测试 token 创建和解码。"""
|
||||
payload = {"sub": "github:12345"}
|
||||
token = create_access_token(payload)
|
||||
decoded = decode_access_token(token)
|
||||
assert decoded is not None
|
||||
assert decoded["sub"] == "github:12345"
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""测试无效 token 解码。"""
|
||||
result = decode_access_token("invalid-token")
|
||||
assert result is None
|
||||
|
||||
def test_decode_empty_token(self):
|
||||
"""测试空 token 解码。"""
|
||||
result = decode_access_token("")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSession:
|
||||
"""Session 端点测试。"""
|
||||
|
||||
def test_session_without_auth(self, client: TestClient):
|
||||
"""未登录时获取 session。"""
|
||||
response = client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is None
|
||||
|
||||
def test_session_with_auth(self, auth_client: TestClient, test_user):
|
||||
"""已登录时获取 session。"""
|
||||
response = auth_client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is not None
|
||||
assert data["user"]["id"] == test_user.id
|
||||
assert data["user"]["name"] == test_user.name
|
||||
|
||||
def test_session_with_invalid_token(self, client: TestClient):
|
||||
"""无效 token 获取 session。"""
|
||||
client.cookies.set("access_token", "invalid-token")
|
||||
response = client.get("/auth/session")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user"] is None
|
||||
|
||||
|
||||
class TestSignout:
|
||||
"""登出测试。"""
|
||||
|
||||
def test_signout(self, auth_client: TestClient):
|
||||
"""测试登出。"""
|
||||
response = auth_client.post("/auth/signout", follow_redirects=False)
|
||||
assert response.status_code == 302
|
||||
78
backend/tests/test_profiles.py
Normal file
78
backend/tests/test_profiles.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Child profile API tests."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
|
||||
def _calc_age(birth_date: date) -> int:
|
||||
today = date.today()
|
||||
return today.year - birth_date.year - (
|
||||
(today.month, today.day) < (birth_date.month, birth_date.day)
|
||||
)
|
||||
|
||||
|
||||
def test_list_profiles_empty(auth_client):
|
||||
response = auth_client.get("/api/profiles")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["profiles"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
|
||||
def test_create_update_delete_profile(auth_client):
|
||||
payload = {
|
||||
"name": "小明",
|
||||
"birth_date": "2020-05-12",
|
||||
"gender": "male",
|
||||
"interests": ["太空", "机器人"],
|
||||
"growth_themes": ["勇气"],
|
||||
}
|
||||
response = auth_client.post("/api/profiles", json=payload)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == payload["name"]
|
||||
assert data["gender"] == payload["gender"]
|
||||
assert data["interests"] == payload["interests"]
|
||||
assert data["growth_themes"] == payload["growth_themes"]
|
||||
assert data["age"] == _calc_age(date.fromisoformat(payload["birth_date"]))
|
||||
|
||||
profile_id = data["id"]
|
||||
|
||||
update_payload = {"growth_themes": ["分享", "独立"]}
|
||||
response = auth_client.put(f"/api/profiles/{profile_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["growth_themes"] == update_payload["growth_themes"]
|
||||
|
||||
response = auth_client.delete(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
|
||||
|
||||
def test_profile_limit_and_duplicate(auth_client):
|
||||
# 先测试重复名称(在达到限制前)
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子1", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子1", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 409 # 重复名称
|
||||
|
||||
# 继续创建到上限
|
||||
for i in range(2, 6):
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": f"孩子{i}", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# 测试数量限制
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "孩子6", "gender": "female"},
|
||||
)
|
||||
assert response.status_code == 400 # 超过5个限制
|
||||
195
backend/tests/test_provider_router.py
Normal file
195
backend/tests/test_provider_router.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.adapters import AdapterConfig
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
|
||||
|
||||
class TestProviderFailover:
|
||||
"""Provider failover 测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_to_second_provider(self):
|
||||
"""第一个 provider 失败时切换到第二个。"""
|
||||
from app.services import provider_router
|
||||
|
||||
# Mock 两个 provider - 使用 spec=False 并显式设置所有属性
|
||||
mock_provider_1 = MagicMock()
|
||||
mock_provider_1.configure_mock(
|
||||
id="provider-1",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key1",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=10,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
mock_provider_2 = MagicMock()
|
||||
mock_provider_2.configure_mock(
|
||||
id="provider-2",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key2",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=5,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
mock_providers = [mock_provider_1, mock_provider_2]
|
||||
|
||||
mock_result = StoryOutput(
|
||||
mode="generated",
|
||||
title="测试故事",
|
||||
story_text="内容",
|
||||
cover_prompt_suggestion="prompt",
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise Exception("First provider failed")
|
||||
return mock_result
|
||||
|
||||
with patch.object(provider_router, "get_providers", return_value=mock_providers):
|
||||
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
|
||||
mock_adapter_class = MagicMock()
|
||||
mock_adapter_instance = MagicMock()
|
||||
mock_adapter_instance.execute = mock_execute
|
||||
mock_adapter_class.return_value = mock_adapter_instance
|
||||
mock_get.return_value = mock_adapter_class
|
||||
|
||||
result = await provider_router.generate_story_content(
|
||||
input_type="keywords",
|
||||
data="测试",
|
||||
)
|
||||
|
||||
assert result == mock_result
|
||||
assert call_count == 2 # 第一个失败,第二个成功
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_providers_fail(self):
|
||||
"""所有 provider 都失败时抛出异常。"""
|
||||
from app.services import provider_router
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.configure_mock(
|
||||
id="provider-1",
|
||||
type="text",
|
||||
adapter="text_primary",
|
||||
api_key="key1",
|
||||
api_base=None,
|
||||
model=None,
|
||||
timeout_ms=60000,
|
||||
max_retries=3,
|
||||
config_ref=None,
|
||||
config_json={},
|
||||
priority=10,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
)
|
||||
mock_providers = [mock_provider]
|
||||
|
||||
async def mock_execute(**kwargs):
|
||||
raise Exception("Provider failed")
|
||||
|
||||
with patch.object(provider_router, "get_providers", return_value=mock_providers):
|
||||
with patch("app.services.adapters.AdapterRegistry.get") as mock_get:
|
||||
mock_adapter_class = MagicMock()
|
||||
mock_adapter_instance = MagicMock()
|
||||
mock_adapter_instance.execute = mock_execute
|
||||
mock_adapter_class.return_value = mock_adapter_instance
|
||||
mock_get.return_value = mock_adapter_class
|
||||
|
||||
with pytest.raises(ValueError, match="No text provider succeeded"):
|
||||
await provider_router.generate_story_content(
|
||||
input_type="keywords",
|
||||
data="测试",
|
||||
)
|
||||
|
||||
|
||||
class TestProviderConfigFromDB:
|
||||
"""从 DB 加载 provider 配置测试。"""
|
||||
|
||||
def test_build_config_from_provider_with_api_key(self):
|
||||
"""Provider 有 api_key 时优先使用。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = "db-api-key"
|
||||
mock_provider.api_base = "https://custom.api.com"
|
||||
mock_provider.model = "custom-model"
|
||||
mock_provider.timeout_ms = 30000
|
||||
mock_provider.max_retries = 5
|
||||
mock_provider.config_ref = None
|
||||
mock_provider.config_json = {}
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "db-api-key"
|
||||
assert config.api_base == "https://custom.api.com"
|
||||
assert config.model == "custom-model"
|
||||
assert config.timeout_ms == 30000
|
||||
assert config.max_retries == 5
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = None
|
||||
mock_provider.api_base = None
|
||||
mock_provider.model = None
|
||||
mock_provider.timeout_ms = None
|
||||
mock_provider.max_retries = None
|
||||
mock_provider.config_ref = "text_api_key"
|
||||
mock_provider.config_json = {}
|
||||
|
||||
with patch("app.services.provider_router.settings") as mock_settings:
|
||||
mock_settings.text_api_key = "settings-api-key"
|
||||
mock_settings.text_model = "gemini-2.0-flash"
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
|
||||
class TestProviderCacheStartup:
|
||||
"""Provider cache 启动加载测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_loaded_on_startup(self):
|
||||
"""启动时加载 provider cache。"""
|
||||
from app.main import _load_provider_cache
|
||||
|
||||
with patch("app.db.database._get_session_factory") as mock_factory:
|
||||
mock_session = AsyncMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock()
|
||||
|
||||
with patch("app.services.provider_cache.reload_providers", new_callable=AsyncMock) as mock_reload:
|
||||
mock_reload.return_value = {"text": [], "image": [], "tts": []}
|
||||
|
||||
await _load_provider_cache()
|
||||
|
||||
mock_reload.assert_called_once()
|
||||
77
backend/tests/test_push_configs.py
Normal file
77
backend/tests/test_push_configs.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Push config API tests."""
|
||||
|
||||
|
||||
def _create_profile(auth_client) -> str:
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "小明", "gender": "male"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_list_update_push_config(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": "20:30",
|
||||
"push_days": [1, 3, 5],
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["child_profile_id"] == profile_id
|
||||
assert data["push_time"].startswith("20:30")
|
||||
assert data["push_days"] == [1, 3, 5]
|
||||
assert data["enabled"] is True
|
||||
|
||||
response = auth_client.get("/api/push-configs")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["total"] == 1
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"enabled": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["enabled"] is False
|
||||
assert data["push_time"].startswith("20:30")
|
||||
assert data["push_days"] == [1, 3, 5]
|
||||
|
||||
|
||||
def test_push_config_validation(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={"child_profile_id": profile_id},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": "19:00",
|
||||
"push_days": [7],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = auth_client.put(
|
||||
"/api/push-configs",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"push_time": None,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
143
backend/tests/test_reading_events.py
Normal file
143
backend/tests/test_reading_events.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Reading event API tests."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.db.models import MemoryItem
|
||||
from app.main import app
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def _create_profile(client: AsyncClient) -> str:
|
||||
response = await client.post(
|
||||
"/api/profiles",
|
||||
json={"name": "小明", "gender": "male"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
async def test_create_reading_event_updates_stats_and_memory(
|
||||
db_session,
|
||||
test_user,
|
||||
auth_token,
|
||||
test_story,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
profile_id = await _create_profile(client)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "completed",
|
||||
"reading_time": 120,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["child_profile_id"] == profile_id
|
||||
assert data["story_id"] == test_story.id
|
||||
assert data["event_type"] == "completed"
|
||||
|
||||
response = await client.get(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
profile = response.json()
|
||||
assert profile["stories_count"] == 1
|
||||
assert profile["total_reading_time"] == 120
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
assert len(items) == 1
|
||||
assert items[0].type == "recent_story"
|
||||
assert items[0].value["story_id"] == test_story.id
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "skipped",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryItem).where(MemoryItem.child_profile_id == profile_id)
|
||||
)
|
||||
assert len(result.scalars().all()) == 1
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": test_story.id,
|
||||
"event_type": "completed",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
response = await client.get(f"/api/profiles/{profile_id}")
|
||||
assert response.status_code == 200
|
||||
profile = response.json()
|
||||
assert profile["stories_count"] == 1
|
||||
assert profile["total_reading_time"] == 120
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
async def test_reading_event_validation_errors(
|
||||
db_session,
|
||||
test_user,
|
||||
auth_token,
|
||||
):
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
client.cookies.set("access_token", auth_token)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": "not-exist",
|
||||
"event_type": "started",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
profile_id = await _create_profile(client)
|
||||
|
||||
response = await client.post(
|
||||
"/api/reading-events",
|
||||
json={
|
||||
"child_profile_id": profile_id,
|
||||
"story_id": 999999,
|
||||
"event_type": "completed",
|
||||
"reading_time": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
257
backend/tests/test_stories.py
Normal file
257
backend/tests/test_stories.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""故事 API 测试。"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.stories import _request_log, RATE_LIMIT_REQUESTS
|
||||
|
||||
|
||||
class TestStoryGenerate:
|
||||
"""故事生成测试。"""
|
||||
|
||||
def test_generate_without_auth(self, client: TestClient):
|
||||
"""未登录时生成故事。"""
|
||||
response = client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_with_empty_data(self, auth_client: TestClient):
|
||||
"""空数据生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_with_invalid_type(self, auth_client: TestClient):
|
||||
"""无效类型生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "invalid", "data": "test"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_story_success(self, auth_client: TestClient, mock_text_provider):
|
||||
"""成功生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
|
||||
|
||||
class TestStoryList:
|
||||
"""故事列表测试。"""
|
||||
|
||||
def test_list_without_auth(self, client: TestClient):
|
||||
"""未登录时获取列表。"""
|
||||
response = client.get("/api/stories")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_list_empty(self, auth_client: TestClient):
|
||||
"""空列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_list_with_stories(self, auth_client: TestClient, test_story):
|
||||
"""有故事时获取列表。"""
|
||||
response = auth_client.get("/api/stories")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == test_story.id
|
||||
assert data[0]["title"] == test_story.title
|
||||
|
||||
def test_list_pagination(self, auth_client: TestClient, test_story):
|
||||
"""分页测试。"""
|
||||
response = auth_client.get("/api/stories?limit=1&offset=0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
|
||||
response = auth_client.get("/api/stories?limit=1&offset=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
|
||||
|
||||
class TestStoryDetail:
|
||||
"""故事详情测试。"""
|
||||
|
||||
def test_get_story_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取详情。"""
|
||||
response = client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_story_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_story_success(self, auth_client: TestClient, test_story):
|
||||
"""成功获取详情。"""
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == test_story.id
|
||||
assert data["title"] == test_story.title
|
||||
assert data["story_text"] == test_story.story_text
|
||||
|
||||
|
||||
class TestStoryDelete:
|
||||
"""故事删除测试。"""
|
||||
|
||||
def test_delete_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时删除。"""
|
||||
response = client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_delete_not_found(self, auth_client: TestClient):
|
||||
"""删除不存在的故事。"""
|
||||
response = auth_client.delete("/api/stories/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_success(self, auth_client: TestClient, test_story):
|
||||
"""成功删除故事。"""
|
||||
response = auth_client.delete(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "Deleted"
|
||||
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
"""Rate limit 测试。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清理 rate limit 缓存。"""
|
||||
_request_log.clear()
|
||||
|
||||
def test_rate_limit_allows_normal_requests(self, auth_client: TestClient, test_story):
|
||||
"""正常请求不触发限流。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS - 1):
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_blocks_excess_requests(self, auth_client: TestClient, test_story):
|
||||
"""超限请求被阻止。"""
|
||||
for _ in range(RATE_LIMIT_REQUESTS):
|
||||
auth_client.get(f"/api/stories/{test_story.id}")
|
||||
|
||||
response = auth_client.get(f"/api/stories/{test_story.id}")
|
||||
assert response.status_code == 429
|
||||
assert "Too many requests" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestImageGenerate:
|
||||
"""封面图片生成测试。"""
|
||||
|
||||
def test_generate_image_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时生成图片。"""
|
||||
response = client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_image_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.post("/api/image/generate/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestAudio:
|
||||
"""语音朗读测试。"""
|
||||
|
||||
def test_get_audio_without_auth(self, client: TestClient, test_story):
|
||||
"""未登录时获取音频。"""
|
||||
response = client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_get_audio_not_found(self, auth_client: TestClient):
|
||||
"""故事不存在。"""
|
||||
response = auth_client.get("/api/audio/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_audio_success(self, auth_client: TestClient, test_story, mock_tts_provider):
|
||||
"""成功获取音频。"""
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
|
||||
|
||||
class TestGenerateFull:
|
||||
"""完整故事生成测试(/api/generate/full)。"""
|
||||
|
||||
def test_generate_full_without_auth(self, client: TestClient):
|
||||
"""未登录时生成完整故事。"""
|
||||
response = client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_generate_full_success(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""成功生成完整故事(含图片)。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "title" in data
|
||||
assert "story_text" in data
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["audio_ready"] is False # 音频按需生成
|
||||
assert data["errors"] == {}
|
||||
|
||||
def test_generate_full_image_failure(self, auth_client: TestClient, mock_text_provider):
|
||||
"""图片生成失败时返回部分成功。"""
|
||||
with patch("app.api.stories.generate_image", new_callable=AsyncMock) as mock_img:
|
||||
mock_img.side_effect = Exception("Image API error")
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={"type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] is None
|
||||
assert "image" in data["errors"]
|
||||
assert "Image API error" in data["errors"]["image"]
|
||||
|
||||
def test_generate_full_with_education_theme(self, auth_client: TestClient, mock_text_provider, mock_image_provider):
|
||||
"""带教育主题生成故事。"""
|
||||
response = auth_client.post(
|
||||
"/api/generate/full",
|
||||
json={
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"education_theme": "勇气与友谊",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_text_provider.assert_called_once()
|
||||
call_kwargs = mock_text_provider.call_args.kwargs
|
||||
assert call_kwargs["education_theme"] == "勇气与友谊"
|
||||
|
||||
|
||||
class TestImageGenerateSuccess:
|
||||
"""封面图片生成成功测试。"""
|
||||
|
||||
def test_generate_image_success(self, auth_client: TestClient, test_story, mock_image_provider):
|
||||
"""成功生成图片。"""
|
||||
response = auth_client.post(f"/api/image/generate/{test_story.id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
68
backend/tests/test_universes.py
Normal file
68
backend/tests/test_universes.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Story universe API tests."""
|
||||
|
||||
|
||||
def _create_profile(auth_client):
|
||||
response = auth_client.post(
|
||||
"/api/profiles",
|
||||
json={
|
||||
"name": "小明",
|
||||
"gender": "male",
|
||||
"interests": ["太空"],
|
||||
"growth_themes": ["勇气"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_list_update_universe(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
payload = {
|
||||
"name": "星际冒险",
|
||||
"protagonist": {"name": "小明", "role": "船长"},
|
||||
"recurring_characters": [{"name": "小七", "role": "机器人"}],
|
||||
"world_settings": {"world_name": "星际学院"},
|
||||
}
|
||||
|
||||
response = auth_client.post(f"/api/profiles/{profile_id}/universes", json=payload)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == payload["name"]
|
||||
|
||||
universe_id = data["id"]
|
||||
|
||||
response = auth_client.get(f"/api/profiles/{profile_id}/universes")
|
||||
assert response.status_code == 200
|
||||
list_data = response.json()
|
||||
assert list_data["total"] == 1
|
||||
|
||||
response = auth_client.put(
|
||||
f"/api/universes/{universe_id}",
|
||||
json={"name": "星际冒险·第二季"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "星际冒险·第二季"
|
||||
|
||||
|
||||
def test_add_achievement(auth_client):
|
||||
profile_id = _create_profile(auth_client)
|
||||
|
||||
response = auth_client.post(
|
||||
f"/api/profiles/{profile_id}/universes",
|
||||
json={
|
||||
"name": "梦幻森林",
|
||||
"protagonist": {"name": "小红"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
universe_id = response.json()["id"]
|
||||
|
||||
response = auth_client.post(
|
||||
f"/api/universes/{universe_id}/achievements",
|
||||
json={"type": "勇气", "description": "克服黑暗"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert {"type": "勇气", "description": "克服黑暗"} in data["achievements"]
|
||||
Reference in New Issue
Block a user