Files
dreamweaver/backend/tests/test_provider_router.py

399 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Provider router 测试 - failover 和配置加载。"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
from app.services.provider_policy import (
DEFAULT_PROVIDERS,
RoutingStrategy,
get_provider_names_from_settings,
list_capability_policies,
)
class TestProviderFailover:
    """Provider failover tests."""

    @staticmethod
    def _make_text_provider(provider_id, api_key, priority):
        """Build a MagicMock standing in for an enabled ``text`` provider row.

        Every provider field used by these tests is set explicitly through
        ``configure_mock`` so it holds a plain value rather than an
        auto-created child mock. Only the id, API key and priority vary.
        """
        provider = MagicMock()
        provider.configure_mock(
            id=provider_id,
            type="text",
            adapter="text_primary",
            api_key=api_key,
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=priority,
            weight=1.0,
            enabled=True,
        )
        return provider

    @staticmethod
    def _make_adapter_class_mock(execute):
        """Return a MagicMock adapter class whose instances use *execute* as ``execute``."""
        adapter_instance = MagicMock()
        adapter_instance.execute = execute
        return MagicMock(return_value=adapter_instance)

    @staticmethod
    def _make_fixed_result_adapter(result):
        """Return a minimal adapter class whose ``execute`` always returns *result*.

        The class accepts a config in its constructor and declares a zero
        ``estimated_cost``.
        """

        class MockAdapter:
            estimated_cost = 0.0

            def __init__(self, config):
                self.config = config

            async def execute(self, **kwargs):
                return result

        return MockAdapter

    @pytest.mark.asyncio
    async def test_failover_to_second_provider(self):
        """When the first provider fails, the call fails over to the second one."""
        from app.services import provider_router

        mock_providers = [
            self._make_text_provider("provider-1", "key1", priority=10),
            self._make_text_provider("provider-2", "key2", priority=5),
        ]
        mock_result = StoryOutput(
            mode="generated",
            title="测试故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        call_count = 0

        async def mock_execute(**kwargs):
            # Fail on the first attempt only; every later attempt succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First provider failed")
            return mock_result

        adapter_class = self._make_adapter_class_mock(mock_execute)
        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch.object(
                provider_router.AdapterRegistry, "get", return_value=adapter_class
            ):
                result = await provider_router.generate_story_content(
                    input_type="keywords",
                    data="测试",
                )

        assert result == mock_result
        assert call_count == 2  # first provider failed, second succeeded

    @pytest.mark.asyncio
    async def test_all_providers_fail(self):
        """When every provider fails, ``generate_story_content`` raises ValueError."""
        from app.services import provider_router

        mock_providers = [self._make_text_provider("provider-1", "key1", priority=10)]

        async def mock_execute(**kwargs):
            raise Exception("Provider failed")

        adapter_class = self._make_adapter_class_mock(mock_execute)
        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch.object(
                provider_router.AdapterRegistry, "get", return_value=adapter_class
            ):
                with pytest.raises(ValueError, match="No text provider succeeded"):
                    await provider_router.generate_story_content(
                        input_type="keywords",
                        data="测试",
                    )

    @pytest.mark.asyncio
    async def test_default_provider_skips_fk_backed_metrics(self):
        """An env-var/default provider has no row in the providers table.

        Metric tables that carry a foreign key to that table must therefore
        not be written for it.
        """
        from contextlib import ExitStack

        from app.services import provider_router

        mock_result = StoryOutput(
            mode="generated",
            title="本地演示故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        adapter_class = self._make_fixed_result_adapter(mock_result)

        # ExitStack replaces a five-level nested ``with`` pyramid; patches are
        # entered in the same order and undone in reverse on exit.
        with ExitStack() as stack:
            mock_providers = stack.enter_context(
                patch.object(
                    provider_router,
                    "_get_providers_with_config",
                    new_callable=AsyncMock,
                )
            )
            # One default provider; the trailing None presumably stands in for
            # the (absent) DB provider record — confirm against the router.
            mock_providers.return_value = [("demo", AdapterConfig(api_key=""), None)]
            stack.enter_context(
                patch.object(provider_router.AdapterRegistry, "get", return_value=adapter_class)
            )
            mock_is_healthy = stack.enter_context(
                patch.object(
                    provider_router.health_checker,
                    "is_healthy",
                    new_callable=AsyncMock,
                )
            )
            mock_record_call = stack.enter_context(
                patch.object(
                    provider_router.metrics_collector,
                    "record_call",
                    new_callable=AsyncMock,
                )
            )
            mock_record_call_result = stack.enter_context(
                patch.object(
                    provider_router.health_checker,
                    "record_call_result",
                    new_callable=AsyncMock,
                )
            )
            mock_record_cost = stack.enter_context(
                patch.object(
                    provider_router.cost_tracker,
                    "record_cost",
                    new_callable=AsyncMock,
                )
            )

            result = await provider_router._route_with_failover(
                "text",
                db=AsyncMock(),
                user_id="user-1",
                input_type="keywords",
                data="测试",
            )

            assert result == mock_result
            # Health checks and per-provider call metrics are skipped...
            mock_is_healthy.assert_not_called()
            mock_record_call.assert_not_called()
            mock_record_call_result.assert_not_called()
            # ...but the cost is still recorded, with no provider_id attached.
            mock_record_cost.assert_awaited_once()
            assert mock_record_cost.await_args.kwargs["provider_id"] is None

    @pytest.mark.asyncio
    async def test_storybook_round_robin_strategy_is_supported(self):
        """Every capability must be able to use a routing policy.

        Storybook in particular must not be missing its round-robin counter.
        """
        from app.services import provider_router
        from app.services.adapters.storybook.primary import Storybook

        mock_storybook = Storybook(
            title="轮询绘本",
            main_character="小星",
            art_style="温暖水彩",
            pages=[],
            cover_prompt="cover",
        )
        adapter_class = self._make_fixed_result_adapter(mock_storybook)

        with patch.object(
            provider_router,
            "_get_providers_with_config",
            new_callable=AsyncMock,
        ) as mock_providers:
            mock_providers.return_value = [
                ("storybook_primary", AdapterConfig(api_key=""), None),
            ]
            with patch.object(
                provider_router.AdapterRegistry, "get", return_value=adapter_class
            ):
                result = await provider_router.generate_storybook(
                    keywords="测试",
                    strategy=RoutingStrategy.ROUND_ROBIN,
                )

        assert result == mock_storybook
class TestProviderPolicy:
    """Boundary tests for provider capabilities and routing policy."""

    def test_policy_lists_all_capabilities(self):
        """The policy table covers every capability; storybook/asr defaults are pinned."""
        listed = {policy["capability"] for policy in list_capability_policies()}
        assert listed == {"text", "image", "tts", "storybook", "asr"}
        assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]
        assert DEFAULT_PROVIDERS["asr"] == ["demo"]

    def test_demo_provider_only_added_to_supported_capabilities(self):
        """With demo providers enabled, ``demo`` is prepended for all but tts."""
        settings = SimpleNamespace(
            text_providers=["gemini"],
            image_providers=["cqtai"],
            tts_providers=["edge_tts"],
            storybook_providers=["storybook_primary"],
            asr_providers=["openai_asr"],
            enable_demo_providers=True,
        )
        expected_by_capability = {
            "text": ["demo", "gemini"],
            "image": ["demo", "cqtai"],
            "storybook": ["demo", "storybook_primary"],
            "tts": ["edge_tts"],
            "asr": ["demo", "openai_asr"],
        }
        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, settings) == expected

    def test_policy_defaults_when_settings_lists_are_empty(self):
        """Empty settings lists fall back to the built-in default provider order."""
        settings = SimpleNamespace(
            text_providers=[],
            image_providers=[],
            tts_providers=[],
            storybook_providers=[],
            asr_providers=[],
            enable_demo_providers=False,
        )
        expected_by_capability = {
            "text": ["gemini", "openai"],
            "tts": ["minimax", "elevenlabs", "edge_tts"],
            "asr": ["demo"],
        }
        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, settings) == expected

    @pytest.mark.asyncio
    async def test_asr_demo_provider_uses_transcript_hint(self):
        """The demo ASR provider returns the transcript hint with full confidence."""
        from app.services import provider_router

        hint = "我想听一个小熊找星星的故事"
        result = await provider_router.transcribe_audio(
            audio_bytes=b"fake-audio",
            file_name="turn.webm",
            mime_type="audio/webm",
            transcript_hint=hint,
        )

        assert result.transcript_text == hint
        assert result.confidence == 1.0
        assert result.provider == "demo"
class TestProviderConfigFromDB:
    """Tests for building an adapter config from a DB provider row."""

    @staticmethod
    def _provider_row(
        adapter,
        *,
        api_key=None,
        api_base=None,
        model=None,
        timeout_ms=None,
        max_retries=None,
        config_ref=None,
    ):
        """Return a MagicMock provider row; omitted fields are None, config_json is {}."""
        row = MagicMock()
        row.configure_mock(
            adapter=adapter,
            api_key=api_key,
            api_base=api_base,
            model=model,
            timeout_ms=timeout_ms,
            max_retries=max_retries,
            config_ref=config_ref,
            config_json={},
        )
        return row

    def test_build_config_from_provider_with_api_key(self):
        """Values stored on the provider row (including api_key) are used directly."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row(
            "text_primary",
            api_key="db-api-key",
            api_base="https://custom.api.com",
            model="custom-model",
            timeout_ms=30000,
            max_retries=5,
        )
        config = _build_config_from_provider(row)

        assert config.api_key == "db-api-key"
        assert config.api_base == "https://custom.api.com"
        assert config.model == "custom-model"
        assert config.timeout_ms == 30000
        assert config.max_retries == 5

    def test_build_config_fallback_to_settings(self):
        """Without an api_key on the row, the key is taken from settings via config_ref."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row("text_primary", config_ref="text_api_key")
        with patch("app.services.provider_router.settings") as mock_settings:
            mock_settings.configure_mock(
                text_api_key="settings-api-key",
                text_model="gemini-2.0-flash",
            )
            config = _build_config_from_provider(row)
            assert config.api_key == "settings-api-key"

    def test_build_config_uses_direct_config_ref_name(self):
        """config_ref may name a settings field directly, which eases admin configuration."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row("antigravity", config_ref="antigravity_api_key")
        with patch("app.services.provider_router.settings") as mock_settings:
            mock_settings.configure_mock(
                antigravity_api_key="antigravity-key",
                antigravity_api_base="https://antigravity.example",
                antigravity_model="gemini-3-pro-image",
            )
            config = _build_config_from_provider(row)
            assert config.api_key == "antigravity-key"
            assert config.api_base == "https://antigravity.example"
            assert config.model == "gemini-3-pro-image"
class TestProviderCacheStartup:
    """Provider cache loading at application startup."""

    @pytest.mark.asyncio
    async def test_cache_loaded_on_startup(self):
        """``_load_provider_cache`` reloads the provider cache exactly once."""
        from app.main import _load_provider_cache

        with patch("app.db.database._get_session_factory") as mock_factory:
            mock_session = AsyncMock()
            session_cm = mock_factory.return_value
            session_cm.__aenter__ = AsyncMock(return_value=mock_session)
            # A bare AsyncMock() for __aexit__ resolves to a truthy mock. If
            # _load_provider_cache uses this object as its ``async with``
            # target, that truthy value would suppress any exception raised
            # inside the block and hide failures. Return False explicitly.
            session_cm.__aexit__ = AsyncMock(return_value=False)
            with patch(
                "app.services.provider_cache.reload_providers",
                new_callable=AsyncMock,
            ) as mock_reload:
                mock_reload.return_value = {"text": [], "image": [], "tts": []}
                await _load_provider_cache()
                # "Awaited", not merely "called": also catches a missing await.
                mock_reload.assert_awaited_once()