414 lines
15 KiB
Python
414 lines
15 KiB
Python
"""Provider router 测试 - failover 和配置加载。"""
|
||
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from app.services.adapters import AdapterConfig
|
||
from app.services.adapters.text.models import StoryOutput
|
||
from app.services.provider_policy import (
|
||
DEFAULT_PROVIDERS,
|
||
RoutingStrategy,
|
||
get_provider_names_from_settings,
|
||
list_capability_policies,
|
||
)
|
||
|
||
|
||
class TestProviderFailover:
    """Failover and routing behaviour of ``provider_router``."""

    @staticmethod
    def _make_db_provider(provider_id, api_key, priority):
        """Return a MagicMock standing in for a DB-backed text provider row.

        Every field the router may read is set explicitly so no attribute
        silently resolves to an auto-generated child mock.
        """
        row = MagicMock()
        row.configure_mock(
            id=provider_id,
            type="text",
            adapter="text_primary",
            api_key=api_key,
            api_base=None,
            model=None,
            timeout_ms=60000,
            max_retries=3,
            config_ref=None,
            config_json={},
            priority=priority,
            weight=1.0,
            enabled=True,
        )
        return row

    @staticmethod
    def _adapter_class_with(execute):
        """Return a MagicMock adapter class whose instances delegate to ``execute``."""
        adapter_instance = MagicMock()
        adapter_instance.execute = execute
        return MagicMock(return_value=adapter_instance)

    @pytest.mark.asyncio
    async def test_failover_to_second_provider(self):
        """If the first provider raises, the router retries with the second one."""
        from app.services import provider_router

        providers = [
            self._make_db_provider("provider-1", "key1", priority=10),
            self._make_db_provider("provider-2", "key2", priority=5),
        ]
        expected = StoryOutput(
            mode="generated",
            title="测试故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        attempts = 0

        async def execute_failing_once(**kwargs):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise Exception("First provider failed")
            return expected

        adapter_cls = self._adapter_class_with(execute_failing_once)
        with patch.object(
            provider_router, "get_providers", return_value=providers
        ), patch("app.services.adapters.AdapterRegistry.get", return_value=adapter_cls):
            result = await provider_router.generate_story_content(
                input_type="keywords",
                data="测试",
            )

        assert result == expected
        # Exactly two attempts: the first provider fails, the second succeeds.
        assert attempts == 2

    @pytest.mark.asyncio
    async def test_all_providers_fail(self):
        """When every provider raises, the router surfaces a ValueError."""
        from app.services import provider_router

        providers = [self._make_db_provider("provider-1", "key1", priority=10)]

        async def execute_always_failing(**kwargs):
            raise Exception("Provider failed")

        adapter_cls = self._adapter_class_with(execute_always_failing)
        with patch.object(
            provider_router, "get_providers", return_value=providers
        ), patch("app.services.adapters.AdapterRegistry.get", return_value=adapter_cls):
            with pytest.raises(ValueError, match="No text provider succeeded"):
                await provider_router.generate_story_content(
                    input_type="keywords",
                    data="测试",
                )

    @pytest.mark.asyncio
    async def test_default_provider_skips_fk_backed_metrics(self):
        """A default/env provider (no ``providers`` row, provider id ``None``)
        must not touch the FK-backed health/metrics recorders; cost is still
        recorded, with ``provider_id=None``."""
        from app.services import provider_router

        expected = StoryOutput(
            mode="generated",
            title="本地演示故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )

        class StubAdapter:
            estimated_cost = 0.0

            def __init__(self, config):
                self.config = config

            async def execute(self, **kwargs):
                return expected

        with patch.object(
            provider_router,
            "_get_providers_with_config",
            new_callable=AsyncMock,
        ) as providers_with_config:
            # Third tuple element is the DB provider row; None marks a default provider.
            providers_with_config.return_value = [("demo", AdapterConfig(api_key=""), None)]

            with patch.object(
                provider_router.AdapterRegistry, "get", return_value=StubAdapter
            ), patch.object(
                provider_router.health_checker, "is_healthy", new_callable=AsyncMock
            ) as is_healthy, patch.object(
                provider_router.metrics_collector, "record_call", new_callable=AsyncMock
            ) as record_call, patch.object(
                provider_router.health_checker,
                "record_call_result",
                new_callable=AsyncMock,
            ) as record_call_result, patch.object(
                provider_router.cost_tracker, "record_cost", new_callable=AsyncMock
            ) as record_cost:
                result = await provider_router._route_with_failover(
                    "text",
                    db=AsyncMock(),
                    user_id="user-1",
                    input_type="keywords",
                    data="测试",
                )

        assert result == expected
        is_healthy.assert_not_called()
        record_call.assert_not_called()
        record_call_result.assert_not_called()
        record_cost.assert_awaited_once()
        assert record_cost.await_args.kwargs["provider_id"] is None

    @pytest.mark.asyncio
    async def test_storybook_round_robin_strategy_is_supported(self):
        """Every capability honours the routing policy; storybook must not be
        missing its round-robin counter."""
        from app.services import provider_router
        from app.services.adapters.storybook.primary import Storybook

        expected = Storybook(
            title="轮询绘本",
            main_character="小星",
            art_style="温暖水彩",
            pages=[],
            cover_prompt="cover",
        )

        class StubAdapter:
            estimated_cost = 0.0

            def __init__(self, config):
                self.config = config

            async def execute(self, **kwargs):
                return expected

        with patch.object(
            provider_router,
            "_get_providers_with_config",
            new_callable=AsyncMock,
        ) as providers_with_config:
            providers_with_config.return_value = [
                ("storybook_primary", AdapterConfig(api_key=""), None),
            ]

            with patch.object(provider_router.AdapterRegistry, "get", return_value=StubAdapter):
                result = await provider_router.generate_storybook(
                    keywords="测试",
                    strategy=RoutingStrategy.ROUND_ROBIN,
                )

        assert result == expected
|
||
|
||
|
||
class TestProviderPolicy:
    """Boundary tests for capability listing and provider-name resolution."""

    def test_policy_lists_all_capabilities(self):
        listed = {policy["capability"] for policy in list_capability_policies()}

        assert listed == {"text", "image", "tts", "storybook", "asr"}
        assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]
        assert DEFAULT_PROVIDERS["asr"] == ["demo"]

    def test_demo_provider_only_added_to_supported_capabilities(self):
        """With demo providers enabled, "demo" is prepended everywhere except tts."""
        settings = SimpleNamespace(
            text_providers=["gemini"],
            image_providers=["cqtai"],
            tts_providers=["edge_tts"],
            storybook_providers=["storybook_primary"],
            asr_providers=["openai_asr"],
            enable_demo_providers=True,
        )
        # Checked in insertion order: text, image, storybook, tts, asr.
        expected_by_capability = {
            "text": ["demo", "gemini"],
            "image": ["demo", "cqtai"],
            "storybook": ["demo", "storybook_primary"],
            "tts": ["edge_tts"],
            "asr": ["demo", "openai_asr"],
        }

        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, settings) == expected, capability

    def test_policy_defaults_when_settings_lists_are_empty(self):
        """Empty per-capability lists fall back to the built-in defaults."""
        settings = SimpleNamespace(
            text_providers=[],
            image_providers=[],
            tts_providers=[],
            storybook_providers=[],
            asr_providers=[],
            enable_demo_providers=False,
        )
        expected_by_capability = {
            "text": ["gemini", "openai"],
            "tts": ["minimax", "elevenlabs", "edge_tts"],
            "asr": ["demo"],
        }

        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, settings) == expected, capability

    @pytest.mark.asyncio
    async def test_asr_demo_provider_uses_transcript_hint(self):
        """The demo ASR provider echoes the transcript hint with full confidence."""
        from app.services import provider_router

        hint = "我想听一个小熊找星星的故事"
        transcription = await provider_router.transcribe_audio(
            audio_bytes=b"fake-audio",
            file_name="turn.webm",
            mime_type="audio/webm",
            transcript_hint=hint,
        )

        assert transcription.transcript_text == hint
        assert transcription.confidence == 1.0
        assert transcription.provider == "demo"

    def test_openai_asr_default_config_uses_openai_env(self):
        """The ``openai_asr`` default config reads the OpenAI key/base and transcription model."""
        from app.services.provider_router import _get_default_config

        with patch("app.services.provider_router.settings") as fake_settings:
            fake_settings.configure_mock(
                openai_api_key="openai-key",
                openai_api_base="https://api.example.com/v1",
                voice_transcription_model="gpt-4o-mini-transcribe",
            )

            config = _get_default_config("openai_asr")

            assert config is not None
            assert config.api_key == "openai-key"
            assert config.api_base == "https://api.example.com/v1"
            assert config.model == "gpt-4o-mini-transcribe"
|
||
|
||
|
||
class TestProviderConfigFromDB:
    """Tests for building adapter configs from DB provider rows."""

    @staticmethod
    def _provider_row(**attributes):
        """Return a MagicMock provider row with ``attributes`` set explicitly."""
        row = MagicMock()
        row.configure_mock(**attributes)
        return row

    def test_build_config_from_provider_with_api_key(self):
        """Fields present on the provider row (including api_key) are used directly."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row(
            adapter="text_primary",
            api_key="db-api-key",
            api_base="https://custom.api.com",
            model="custom-model",
            timeout_ms=30000,
            max_retries=5,
            config_ref=None,
            config_json={},
        )

        config = _build_config_from_provider(row)

        assert config.api_key == "db-api-key"
        assert config.api_base == "https://custom.api.com"
        assert config.model == "custom-model"
        assert config.timeout_ms == 30000
        assert config.max_retries == 5

    def test_build_config_fallback_to_settings(self):
        """Without an api_key on the row, the key falls back to settings."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row(
            adapter="text_primary",
            api_key=None,
            api_base=None,
            model=None,
            timeout_ms=None,
            max_retries=None,
            config_ref="text_api_key",
            config_json={},
        )

        with patch("app.services.provider_router.settings") as fake_settings:
            fake_settings.configure_mock(
                text_api_key="settings-api-key",
                text_model="gemini-2.0-flash",
            )

            config = _build_config_from_provider(row)

            assert config.api_key == "settings-api-key"

    def test_build_config_uses_direct_config_ref_name(self):
        """config_ref may name a settings field directly, simplifying admin configuration."""
        from app.services.provider_router import _build_config_from_provider

        row = self._provider_row(
            adapter="antigravity",
            api_key=None,
            api_base=None,
            model=None,
            timeout_ms=None,
            max_retries=None,
            config_ref="antigravity_api_key",
            config_json={},
        )

        with patch("app.services.provider_router.settings") as fake_settings:
            fake_settings.configure_mock(
                antigravity_api_key="antigravity-key",
                antigravity_api_base="https://antigravity.example",
                antigravity_model="gemini-3-pro-image",
            )

            config = _build_config_from_provider(row)

            assert config.api_key == "antigravity-key"
            assert config.api_base == "https://antigravity.example"
            assert config.model == "gemini-3-pro-image"
|
||
|
||
|
||
class TestProviderCacheStartup:
    """Provider cache loading during application startup."""

    @pytest.mark.asyncio
    async def test_cache_loaded_on_startup(self):
        """Startup awaits ``reload_providers`` exactly once."""
        from app.main import _load_provider_cache

        with patch("app.db.database._get_session_factory") as mock_factory:
            mock_session = AsyncMock()
            # Wire the session context manager for both plausible call shapes:
            # ``async with _get_session_factory() as s`` and the sessionmaker
            # style ``async with _get_session_factory()() as s``.
            # NOTE(review): confirm which shape _load_provider_cache uses.
            for session_cm in (
                mock_factory.return_value,
                mock_factory.return_value.return_value,
            ):
                session_cm.__aenter__ = AsyncMock(return_value=mock_session)
                # A bare AsyncMock() resolves to a truthy mock, and a truthy
                # __aexit__ result suppresses exceptions raised inside the
                # ``async with`` body. Return False so failures propagate.
                session_cm.__aexit__ = AsyncMock(return_value=False)

            with patch(
                "app.services.provider_cache.reload_providers",
                new_callable=AsyncMock,
            ) as mock_reload:
                # Cover every capability listed by the routing policy.
                mock_reload.return_value = {
                    "text": [],
                    "image": [],
                    "tts": [],
                    "storybook": [],
                    "asr": [],
                }

                await _load_provider_cache()

        # For an AsyncMock, require the coroutine to have been awaited,
        # not merely created.
        mock_reload.assert_awaited_once()
|