"""Pytest fixtures for backend tests.""" import os from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, patch import pytest from fastapi.testclient import TestClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing") os.environ.setdefault("DATABASE_URL", "sqlite+aiosqlite:///:memory:") from app.core.config import settings from app.core.security import create_access_token from app.db.database import get_db from app.db.models import Base, Story, User from app.main import app @pytest.fixture async def async_engine(): """Create an in-memory database engine.""" engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest.fixture async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: """Create a database session.""" session_factory = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False ) async with session_factory() as session: yield session @pytest.fixture async def test_user(db_session: AsyncSession) -> User: """Create a test user.""" user = User( id="github:12345", name="Test User", avatar_url="https://example.com/avatar.png", provider="github", ) db_session.add(user) await db_session.commit() await db_session.refresh(user) return user @pytest.fixture async def test_story(db_session: AsyncSession, test_user: User) -> Story: """Create a plain generated story.""" story = Story( user_id=test_user.id, title="测试故事", story_text="从前有一只小兔子。", cover_prompt="A cute rabbit in a forest", mode="generated", generation_status="narrative_ready", image_status="not_requested", audio_status="not_requested", ) db_session.add(story) await db_session.commit() await db_session.refresh(story) return story @pytest.fixture async def storybook_story(db_session: AsyncSession, test_user: User) -> Story: """Create a storybook-mode story.""" story = Story( user_id=test_user.id, title="森林绘本冒险", story_text=None, pages=[ { "page_number": 1, "text": "小兔子走进了会发光的森林。", "image_prompt": "A glowing forest with a curious rabbit", "image_url": "https://example.com/page-1.png", }, { "page_number": 2, "text": "它遇见了一位会唱歌的萤火虫朋友。", "image_prompt": "A rabbit meeting a singing firefly", "image_url": None, }, ], cover_prompt="A magical forest storybook cover", image_url="https://example.com/storybook-cover.png", mode="storybook", generation_status="degraded_completed", image_status="failed", audio_status="not_requested", last_error="第 2 页插图生成失败", ) db_session.add(story) await db_session.commit() await db_session.refresh(story) return story @pytest.fixture async def degraded_story_with_text(db_session: AsyncSession, test_user: User) -> Story: """Create a readable story whose image generation already failed.""" story = Story( user_id=test_user.id, title="部分完成的测试故事", story_text="从前有一只小兔子继续冒险。", cover_prompt="A rabbit under the moon", mode="generated", generation_status="degraded_completed", image_status="failed", audio_status="not_requested", last_error="封面生成失败", ) db_session.add(story) await db_session.commit() await db_session.refresh(story) return story @pytest.fixture def auth_token(test_user: User) -> str: """Create a JWT token for the test user.""" return create_access_token({"sub": test_user.id}) @pytest.fixture def client(db_session: AsyncSession) -> TestClient: """Create a test client.""" async def override_get_db(): yield db_session app.dependency_overrides[get_db] = override_get_db with TestClient(app) as c: yield c app.dependency_overrides.clear() @pytest.fixture def auth_client(client: TestClient, auth_token: str) -> TestClient: """Create an authenticated test client.""" client.cookies.set("access_token", auth_token) return client @pytest.fixture(autouse=True) def bypass_rate_limit(): """Bypass rate limiting in most tests.""" with patch("app.core.rate_limiter.get_redis", new_callable=AsyncMock) as mock_redis: redis_instance = AsyncMock() redis_instance.incr.return_value = 1 redis_instance.expire.return_value = True redis_instance.get.return_value = None redis_instance.ttl.return_value = 0 redis_instance.delete.return_value = 1 mock_redis.return_value = redis_instance yield redis_instance @pytest.fixture(autouse=True) def isolated_story_audio_cache(tmp_path, monkeypatch): """Use an isolated directory for cached story audio files.""" monkeypatch.setattr(settings, "story_audio_cache_dir", str(tmp_path / "audio")) yield @pytest.fixture def mock_text_provider(): """Mock text generation.""" from app.services.adapters.text.models import StoryOutput mock_result = StoryOutput( mode="generated", title="小兔子的冒险", story_text="从前有一只小兔子。", cover_prompt_suggestion="A cute rabbit", ) with patch("app.services.story_service.generate_story_content", new_callable=AsyncMock) as mock: mock.return_value = mock_result yield mock @pytest.fixture def mock_image_provider(): """Mock image generation.""" with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock: mock.return_value = "https://example.com/image.png" yield mock @pytest.fixture def mock_tts_provider(): """Mock text-to-speech generation.""" with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock: mock.return_value = b"fake-audio-bytes" yield mock @pytest.fixture def mock_all_providers(mock_text_provider, mock_image_provider, mock_tts_provider): """Group all mocked providers.""" return { "text_primary": mock_text_provider, "image_primary": mock_image_provider, "tts_primary": mock_tts_provider, }