# dreamweaver/backend/tests/conftest.py (153 lines, 4.5 KiB, Python)

"""测试配置和 fixtures。"""
import os
from collections.abc import AsyncGenerator, Iterator
from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
# Seed test-safe defaults BEFORE the ``app`` modules below are imported, since
# those imports may read configuration at import time (presumably via the
# app's settings module — confirm). ``setdefault`` leaves any value that is
# already exported in the environment untouched.
_TEST_ENV_DEFAULTS = {
    "SECRET_KEY": "test-secret-key-for-testing",
    "DATABASE_URL": "sqlite+aiosqlite:///:memory:",
}
for _name, _value in _TEST_ENV_DEFAULTS.items():
    os.environ.setdefault(_name, _value)
del _name, _value
from app.core.security import create_access_token
from app.db.database import get_db
from app.db.models import Base, Story, User
from app.main import app
@pytest.fixture
async def async_engine():
    """Yield a throwaway async engine backed by in-memory SQLite.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed out, and the engine is disposed when the depending test is torn
    down. The fixture is function-scoped (pytest's default), so each test gets
    a fresh, empty database.

    NOTE(review): this relies on the aiosqlite dialect reusing a single
    connection for ``:memory:`` URLs (StaticPool in SQLAlchemy 2.x); otherwise
    each new connection would see its own table-less database. It also
    presumes pytest-asyncio runs in auto mode so a plain ``@pytest.fixture``
    can be async — confirm in the pytest config.
    """
    memory_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with memory_engine.begin() as connection:
        # create_all is a synchronous API; run_sync bridges it onto the async connection.
        await connection.run_sync(Base.metadata.create_all)
    yield memory_engine
    await memory_engine.dispose()
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield a single ``AsyncSession`` bound to the per-test engine.

    ``expire_on_commit=False`` keeps ORM attributes readable after a commit
    without forcing a lazy reload. The session is closed when the
    ``async with`` block exits at fixture teardown.
    """
    make_session = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with make_session() as session:
        yield session
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Insert and return the shared test user (GitHub provider, id ``github:12345``).

    The row is committed and then refreshed, so any server-populated columns
    are loaded on the returned instance.
    """
    user = User(
        provider="github",
        id="github:12345",
        name="Test User",
        avatar_url="https://example.com/avatar.png",
    )
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return user
@pytest.fixture
async def test_story(db_session: AsyncSession, test_user: User) -> Story:
    """Insert and return a story owned by ``test_user`` with mode ``"generated"``.

    Committed and refreshed the same way as the user fixture, so the returned
    instance reflects the persisted row.
    """
    owner_id = test_user.id
    story = Story(
        user_id=owner_id,
        mode="generated",
        title="测试故事",
        story_text="从前有一只小兔子...",
        cover_prompt="A cute rabbit in a forest",
    )
    db_session.add(story)
    await db_session.commit()
    await db_session.refresh(story)
    return story
@pytest.fixture
def auth_token(test_user: User) -> str:
    """Return a JWT access token whose ``sub`` claim is the test user's id."""
    claims = {"sub": test_user.id}
    token = create_access_token(claims)
    return token
@pytest.fixture
def client(db_session: AsyncSession) -> Iterator[TestClient]:
    """Yield a ``TestClient`` whose ``get_db`` dependency returns ``db_session``.

    Requests therefore use the same session as the data fixtures, so rows
    written through ``db_session`` are visible to request handlers.

    The override is removed in a ``finally`` block. Without it, a failure
    while starting or stopping the app (``TestClient.__enter__``/``__exit__``
    run the lifespan) would skip the cleanup and leak the override into every
    later test.

    NOTE(review): Starlette's TestClient drives the app on its own event loop
    (via an anyio portal), so this session may be used from a different loop
    than the one that created it — revisit if loop-related errors appear.
    """

    async def override_get_db():
        # Hand every request the test-owned session instead of opening a new one.
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.clear()
@pytest.fixture
def auth_client(client: TestClient, auth_token: str) -> TestClient:
    """Return the ``client`` fixture with the test user's ``access_token`` cookie set.

    This mutates and returns the *same* object as ``client``, so a test that
    requests both fixtures gets an authenticated client either way.
    """
    cookie_jar = client.cookies
    cookie_jar.set("access_token", auth_token)
    return client
@pytest.fixture(autouse=True)
def bypass_rate_limit():
    """Bypass rate limiting by default so tests that are not about it run normally.

    ``app.core.rate_limiter.get_redis`` is replaced by an ``AsyncMock`` that,
    when awaited, returns a fake Redis client. The fake client's methods
    return values that never trip the limiter:

    - ``incr``   -> 1    (always looks like the first hit)
    - ``expire`` -> True
    - ``get``    -> None (no lockout record)
    - ``ttl``    -> 0
    - ``delete`` -> 1

    Yields the fake client so a test can inspect or reconfigure it.
    """
    fake_redis = AsyncMock()
    fake_redis.configure_mock(
        **{
            "incr.return_value": 1,
            "expire.return_value": True,
            "get.return_value": None,
            "ttl.return_value": 0,
            "delete.return_value": 1,
        }
    )
    # Extra kwargs to patch() are forwarded to new_callable, i.e. AsyncMock(return_value=...).
    with patch(
        "app.core.rate_limiter.get_redis",
        new_callable=AsyncMock,
        return_value=fake_redis,
    ):
        yield fake_redis
@pytest.fixture
def mock_text_provider():
    """Patch ``story_service.generate_story_content`` to return a canned story.

    The replacement is an ``AsyncMock`` that, when awaited, returns a fixed
    ``StoryOutput`` in ``"generated"`` mode. The mock is yielded so tests can
    assert on how it was called.
    """
    # Imported inside the fixture so the adapter module is loaded only when
    # this fixture is actually used.
    from app.services.adapters.text.models import StoryOutput

    canned_story = StoryOutput(
        mode="generated",
        title="小兔子的冒险",
        story_text="从前有一只小兔子...",
        cover_prompt_suggestion="A cute rabbit",
    )
    target = "app.services.story_service.generate_story_content"
    with patch(target, new_callable=AsyncMock, return_value=canned_story) as text_mock:
        yield text_mock
@pytest.fixture
def mock_image_provider():
    """Patch ``story_service.generate_image`` to return a fixed image URL.

    Yields the ``AsyncMock`` so tests can inspect its calls.
    """
    fake_url = "https://example.com/image.png"
    with patch(
        "app.services.story_service.generate_image",
        new_callable=AsyncMock,
        return_value=fake_url,
    ) as image_mock:
        yield image_mock
@pytest.fixture
def mock_tts_provider():
    """Patch ``provider_router.text_to_speech`` to return fixed fake audio bytes.

    Yields the ``AsyncMock`` so tests can inspect its calls.

    NOTE(review): unlike the text/image mocks, which patch names in
    ``story_service``, this patches ``provider_router`` — confirm that this is
    the module where the TTS call is looked up at runtime.
    """
    audio_bytes = b"fake-audio-bytes"
    with patch(
        "app.services.provider_router.text_to_speech",
        new_callable=AsyncMock,
        return_value=audio_bytes,
    ) as tts_mock:
        yield tts_mock
@pytest.fixture
def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider):
    """Activate the text, image and TTS mocks together and expose them by role.

    Returns a dict mapping ``text_primary`` / ``image_primary`` /
    ``tts_primary`` to the corresponding ``AsyncMock``.
    """
    return dict(
        text_primary=mock_text_provider,
        image_primary=mock_image_provider,
        tts_primary=mock_tts_provider,
    )