"""Provider router 测试 - failover 和配置加载。""" from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.services.adapters import AdapterConfig from app.services.adapters.text.models import StoryOutput from app.services.provider_policy import ( DEFAULT_PROVIDERS, RoutingStrategy, get_provider_names_from_settings, list_capability_policies, ) class TestProviderFailover: """Provider failover 测试。""" @pytest.mark.asyncio async def test_failover_to_second_provider(self): """第一个 provider 失败时切换到第二个。""" from app.services import provider_router # Mock 两个 provider - 使用 spec=False 并显式设置所有属性 mock_provider_1 = MagicMock() mock_provider_1.configure_mock( id="provider-1", type="text", adapter="text_primary", api_key="key1", api_base=None, model=None, timeout_ms=60000, max_retries=3, config_ref=None, config_json={}, priority=10, weight=1.0, enabled=True, ) mock_provider_2 = MagicMock() mock_provider_2.configure_mock( id="provider-2", type="text", adapter="text_primary", api_key="key2", api_base=None, model=None, timeout_ms=60000, max_retries=3, config_ref=None, config_json={}, priority=5, weight=1.0, enabled=True, ) mock_providers = [mock_provider_1, mock_provider_2] mock_result = StoryOutput( mode="generated", title="测试故事", story_text="内容", cover_prompt_suggestion="prompt", ) call_count = 0 async def mock_execute(**kwargs): nonlocal call_count call_count += 1 if call_count == 1: raise Exception("First provider failed") return mock_result with patch.object(provider_router, "get_providers", return_value=mock_providers): with patch("app.services.adapters.AdapterRegistry.get") as mock_get: mock_adapter_class = MagicMock() mock_adapter_instance = MagicMock() mock_adapter_instance.execute = mock_execute mock_adapter_class.return_value = mock_adapter_instance mock_get.return_value = mock_adapter_class result = await provider_router.generate_story_content( input_type="keywords", data="测试", ) assert result == mock_result assert call_count == 2 # 第一个失败,第二个成功 @pytest.mark.asyncio async def test_all_providers_fail(self): """所有 provider 都失败时抛出异常。""" from app.services import provider_router mock_provider = MagicMock() mock_provider.configure_mock( id="provider-1", type="text", adapter="text_primary", api_key="key1", api_base=None, model=None, timeout_ms=60000, max_retries=3, config_ref=None, config_json={}, priority=10, weight=1.0, enabled=True, ) mock_providers = [mock_provider] async def mock_execute(**kwargs): raise Exception("Provider failed") with patch.object(provider_router, "get_providers", return_value=mock_providers): with patch("app.services.adapters.AdapterRegistry.get") as mock_get: mock_adapter_class = MagicMock() mock_adapter_instance = MagicMock() mock_adapter_instance.execute = mock_execute mock_adapter_class.return_value = mock_adapter_instance mock_get.return_value = mock_adapter_class with pytest.raises(ValueError, match="No text provider succeeded"): await provider_router.generate_story_content( input_type="keywords", data="测试", ) @pytest.mark.asyncio async def test_default_provider_skips_fk_backed_metrics(self): """环境变量/default provider 没有 providers 表记录,不写带外键的指标表。""" from app.services import provider_router mock_result = StoryOutput( mode="generated", title="本地演示故事", story_text="内容", cover_prompt_suggestion="prompt", ) class MockAdapter: estimated_cost = 0.0 def __init__(self, config): self.config = config async def execute(self, **kwargs): return mock_result with patch.object( provider_router, "_get_providers_with_config", new_callable=AsyncMock, ) as mock_providers: mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)] with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter): with patch.object( provider_router.health_checker, "is_healthy", new_callable=AsyncMock, ) as mock_is_healthy: with patch.object( provider_router.metrics_collector, "record_call", new_callable=AsyncMock, ) as mock_record_call: with patch.object( provider_router.health_checker, "record_call_result", new_callable=AsyncMock, ) as mock_record_call_result: with patch.object( provider_router.cost_tracker, "record_cost", new_callable=AsyncMock, ) as mock_record_cost: result = await provider_router._route_with_failover( "text", db=AsyncMock(), user_id="user-1", input_type="keywords", data="测试", ) assert result == mock_result mock_is_healthy.assert_not_called() mock_record_call.assert_not_called() mock_record_call_result.assert_not_called() mock_record_cost.assert_awaited_once() assert mock_record_cost.await_args.kwargs["provider_id"] is None @pytest.mark.asyncio async def test_storybook_round_robin_strategy_is_supported(self): """所有能力都应能使用 routing policy,storybook 不能漏掉轮询计数器。""" from app.services import provider_router from app.services.adapters.storybook.primary import Storybook mock_storybook = Storybook( title="轮询绘本", main_character="小星", art_style="温暖水彩", pages=[], cover_prompt="cover", ) class MockAdapter: estimated_cost = 0.0 def __init__(self, config): self.config = config async def execute(self, **kwargs): return mock_storybook with patch.object( provider_router, "_get_providers_with_config", new_callable=AsyncMock, ) as mock_providers: mock_providers.return_value = [ ("storybook_primary", AdapterConfig(api_key=""), None), ] with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter): result = await provider_router.generate_storybook( keywords="测试", strategy=RoutingStrategy.ROUND_ROBIN, ) assert result == mock_storybook class TestProviderPolicy: """Provider capability / routing policy boundary tests.""" def test_policy_lists_all_capabilities(self): policies = list_capability_policies() capabilities = {item["capability"] for item in policies} assert capabilities == {"text", "image", "tts", "storybook", "asr"} assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"] assert DEFAULT_PROVIDERS["asr"] == ["demo"] def test_demo_provider_only_added_to_supported_capabilities(self): settings = SimpleNamespace( text_providers=["gemini"], image_providers=["cqtai"], tts_providers=["edge_tts"], storybook_providers=["storybook_primary"], asr_providers=["openai_asr"], enable_demo_providers=True, ) assert get_provider_names_from_settings("text", settings) == ["demo", "gemini"] assert get_provider_names_from_settings("image", settings) == ["demo", "cqtai"] assert get_provider_names_from_settings("storybook", settings) == [ "demo", "storybook_primary", ] assert get_provider_names_from_settings("tts", settings) == ["edge_tts"] assert get_provider_names_from_settings("asr", settings) == ["demo", "openai_asr"] def test_policy_defaults_when_settings_lists_are_empty(self): settings = SimpleNamespace( text_providers=[], image_providers=[], tts_providers=[], storybook_providers=[], asr_providers=[], enable_demo_providers=False, ) assert get_provider_names_from_settings("text", settings) == ["gemini", "openai"] assert get_provider_names_from_settings("tts", settings) == [ "minimax", "elevenlabs", "edge_tts", ] assert get_provider_names_from_settings("asr", settings) == ["demo"] @pytest.mark.asyncio async def test_asr_demo_provider_uses_transcript_hint(self): from app.services import provider_router result = await provider_router.transcribe_audio( audio_bytes=b"fake-audio", file_name="turn.webm", mime_type="audio/webm", transcript_hint="我想听一个小熊找星星的故事", ) assert result.transcript_text == "我想听一个小熊找星星的故事" assert result.confidence == 1.0 assert result.provider == "demo" def test_openai_asr_default_config_uses_openai_env(self): from app.services.provider_router import _get_default_config with patch("app.services.provider_router.settings") as mock_settings: mock_settings.openai_api_key = "openai-key" mock_settings.openai_api_base = "https://api.example.com/v1" mock_settings.voice_transcription_model = "gpt-4o-mini-transcribe" config = _get_default_config("openai_asr") assert config is not None assert config.api_key == "openai-key" assert config.api_base == "https://api.example.com/v1" assert config.model == "gpt-4o-mini-transcribe" class TestProviderConfigFromDB: """从 DB 加载 provider 配置测试。""" def test_build_config_from_provider_with_api_key(self): """Provider 有 api_key 时优先使用。""" from app.services.provider_router import _build_config_from_provider mock_provider = MagicMock() mock_provider.adapter = "text_primary" mock_provider.api_key = "db-api-key" mock_provider.api_base = "https://custom.api.com" mock_provider.model = "custom-model" mock_provider.timeout_ms = 30000 mock_provider.max_retries = 5 mock_provider.config_ref = None mock_provider.config_json = {} config = _build_config_from_provider(mock_provider) assert config.api_key == "db-api-key" assert config.api_base == "https://custom.api.com" assert config.model == "custom-model" assert config.timeout_ms == 30000 assert config.max_retries == 5 def test_build_config_fallback_to_settings(self): """Provider 无 api_key 时回退到 settings。""" from app.services.provider_router import _build_config_from_provider mock_provider = MagicMock() mock_provider.adapter = "text_primary" mock_provider.api_key = None mock_provider.api_base = None mock_provider.model = None mock_provider.timeout_ms = None mock_provider.max_retries = None mock_provider.config_ref = "text_api_key" mock_provider.config_json = {} with patch("app.services.provider_router.settings") as mock_settings: mock_settings.text_api_key = "settings-api-key" mock_settings.text_model = "gemini-2.0-flash" config = _build_config_from_provider(mock_provider) assert config.api_key == "settings-api-key" def test_build_config_uses_direct_config_ref_name(self): """config_ref 可以直接使用 settings 字段名,便于后台配置。""" from app.services.provider_router import _build_config_from_provider mock_provider = MagicMock() mock_provider.adapter = "antigravity" mock_provider.api_key = None mock_provider.api_base = None mock_provider.model = None mock_provider.timeout_ms = None mock_provider.max_retries = None mock_provider.config_ref = "antigravity_api_key" mock_provider.config_json = {} with patch("app.services.provider_router.settings") as mock_settings: mock_settings.antigravity_api_key = "antigravity-key" mock_settings.antigravity_api_base = "https://antigravity.example" mock_settings.antigravity_model = "gemini-3-pro-image" config = _build_config_from_provider(mock_provider) assert config.api_key == "antigravity-key" assert config.api_base == "https://antigravity.example" assert config.model == "gemini-3-pro-image" class TestProviderCacheStartup: """Provider cache 启动加载测试。""" @pytest.mark.asyncio async def test_cache_loaded_on_startup(self): """启动时加载 provider cache。""" from app.main import _load_provider_cache with patch("app.db.database._get_session_factory") as mock_factory: mock_session = AsyncMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock() with patch( "app.services.provider_cache.reload_providers", new_callable=AsyncMock, ) as mock_reload: mock_reload.return_value = {"text": [], "image": [], "tts": []} await _load_provider_cache() mock_reload.assert_called_once()