Files
dreamweaver/backend/tests/test_provider_router.py
Yuyan 44405ff7ac
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
feat: enable local docker demo mode
2026-04-18 12:01:27 +08:00

263 lines
9.6 KiB
Python

"""Provider router 测试 - failover 和配置加载。"""
from contextlib import ExitStack
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
def _make_text_provider(provider_id, api_key, priority):
    """Return a MagicMock standing in for an enabled ``text`` provider row.

    Every provider field is set explicitly to a plain value, instead of being
    left as an auto-created MagicMock attribute, so the router reads realistic
    data rather than nested mocks.

    Args:
        provider_id: Value for the ``id`` attribute.
        api_key: Value for the ``api_key`` attribute.
        priority: Value for the ``priority`` attribute.

    Returns:
        A configured ``MagicMock``.
    """
    provider = MagicMock()
    provider.configure_mock(
        id=provider_id,
        type="text",
        adapter="text_primary",
        api_key=api_key,
        api_base=None,
        model=None,
        timeout_ms=60000,
        max_retries=3,
        config_ref=None,
        config_json={},
        priority=priority,
        weight=1.0,
        enabled=True,
    )
    return provider


def _adapter_class_with_execute(execute):
    """Return a mock adapter class whose instances use ``execute`` as ``execute``.

    Calling the returned mock (i.e. instantiating the "adapter") always yields
    the same instance, whatever arguments are passed.
    """
    adapter_instance = MagicMock()
    adapter_instance.execute = execute
    return MagicMock(return_value=adapter_instance)


class TestProviderFailover:
    """Provider failover tests."""

    @pytest.mark.asyncio
    async def test_failover_to_second_provider(self):
        """When the first provider fails, the router falls over to the next one."""
        from app.services import provider_router

        mock_providers = [
            _make_text_provider("provider-1", "key1", priority=10),
            _make_text_provider("provider-2", "key2", priority=5),
        ]
        mock_result = StoryOutput(
            mode="generated",
            title="测试故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )
        call_count = 0

        async def mock_execute(**kwargs):
            # Fail only on the very first call, whichever provider that is, so
            # the test does not depend on the order the router tries providers in.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First provider failed")
            return mock_result

        adapter_class = _adapter_class_with_execute(mock_execute)
        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            # Patch the registry object the router itself references.
            with patch.object(provider_router.AdapterRegistry, "get", return_value=adapter_class):
                result = await provider_router.generate_story_content(
                    input_type="keywords",
                    data="测试",
                )

        assert result == mock_result
        assert call_count == 2  # first attempt failed, second succeeded

    @pytest.mark.asyncio
    async def test_all_providers_fail(self):
        """When every provider fails, a ValueError is raised."""
        from app.services import provider_router

        mock_providers = [_make_text_provider("provider-1", "key1", priority=10)]

        async def mock_execute(**kwargs):
            raise Exception("Provider failed")

        adapter_class = _adapter_class_with_execute(mock_execute)
        with patch.object(provider_router, "get_providers", return_value=mock_providers):
            with patch.object(provider_router.AdapterRegistry, "get", return_value=adapter_class):
                with pytest.raises(ValueError, match="No text provider succeeded"):
                    await provider_router.generate_story_content(
                        input_type="keywords",
                        data="测试",
                    )

    @pytest.mark.asyncio
    async def test_default_provider_skips_fk_backed_metrics(self):
        """An env-var/default provider has no ``providers`` table row.

        It must therefore not write to metric tables that hold a foreign key
        to that table. Cost is still recorded, with ``provider_id=None``.
        """
        from app.services import provider_router

        mock_result = StoryOutput(
            mode="generated",
            title="本地演示故事",
            story_text="内容",
            cover_prompt_suggestion="prompt",
        )

        class MockAdapter:
            estimated_cost = 0.0

            def __init__(self, config):
                self.config = config

            async def execute(self, **kwargs):
                return mock_result

        # ExitStack keeps the six patches flat instead of nesting six `with`s.
        with ExitStack() as stack:
            # The third tuple element is None for the default provider
            # (presumably the absent DB provider row; confirm in provider_router).
            stack.enter_context(
                patch.object(
                    provider_router,
                    "_get_providers_with_config",
                    new_callable=AsyncMock,
                    return_value=[("demo", AdapterConfig(api_key=""), None)],
                )
            )
            stack.enter_context(
                patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter)
            )
            mock_is_healthy = stack.enter_context(
                patch.object(
                    provider_router.health_checker,
                    "is_healthy",
                    new_callable=AsyncMock,
                )
            )
            mock_record_call = stack.enter_context(
                patch.object(
                    provider_router.metrics_collector,
                    "record_call",
                    new_callable=AsyncMock,
                )
            )
            mock_record_call_result = stack.enter_context(
                patch.object(
                    provider_router.health_checker,
                    "record_call_result",
                    new_callable=AsyncMock,
                )
            )
            mock_record_cost = stack.enter_context(
                patch.object(
                    provider_router.cost_tracker,
                    "record_cost",
                    new_callable=AsyncMock,
                )
            )

            result = await provider_router._route_with_failover(
                "text",
                db=AsyncMock(),
                user_id="user-1",
                input_type="keywords",
                data="测试",
            )

        # Mocks keep their call history after the patches are undone.
        assert result == mock_result
        mock_is_healthy.assert_not_called()
        mock_record_call.assert_not_called()
        mock_record_call_result.assert_not_called()
        mock_record_cost.assert_awaited_once()
        assert mock_record_cost.await_args.kwargs["provider_id"] is None
class TestProviderConfigFromDB:
    """Tests for building an adapter config from a provider row loaded from the DB."""

    def test_build_config_from_provider_with_api_key(self):
        """A provider row that carries its own api_key has its values used as-is."""
        from app.services.provider_router import _build_config_from_provider

        # MagicMock(**kwargs) configures these attributes on construction.
        provider_row = MagicMock(
            adapter="text_primary",
            api_key="db-api-key",
            api_base="https://custom.api.com",
            model="custom-model",
            timeout_ms=30000,
            max_retries=5,
            config_ref=None,
            config_json={},
        )

        built = _build_config_from_provider(provider_row)

        expected_fields = {
            "api_key": "db-api-key",
            "api_base": "https://custom.api.com",
            "model": "custom-model",
            "timeout_ms": 30000,
            "max_retries": 5,
        }
        for field_name, expected_value in expected_fields.items():
            assert getattr(built, field_name) == expected_value, field_name

    def test_build_config_fallback_to_settings(self):
        """A provider row without an api_key gets its key from settings instead."""
        from app.services.provider_router import _build_config_from_provider

        provider_row = MagicMock(
            adapter="text_primary",
            api_key=None,
            api_base=None,
            model=None,
            timeout_ms=None,
            max_retries=None,
            config_ref="text_api_key",
            config_json={},
        )

        # Extra patch() kwargs are forwarded to the replacement MagicMock.
        with patch(
            "app.services.provider_router.settings",
            text_api_key="settings-api-key",
            text_model="gemini-2.0-flash",
        ):
            built = _build_config_from_provider(provider_row)

        assert built.api_key == "settings-api-key"
class TestProviderCacheStartup:
    """Tests for loading the provider cache at application startup."""

    @pytest.mark.asyncio
    async def test_cache_loaded_on_startup(self):
        """``_load_provider_cache`` triggers exactly one provider-cache reload."""
        from app.main import _load_provider_cache

        with patch("app.db.database._get_session_factory") as mock_factory:
            mock_session = AsyncMock()
            mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
            # __aexit__ must resolve to a falsy value. A bare AsyncMock() awaits to
            # a (truthy) mock object, which would make `async with` on this object
            # silently swallow any exception raised inside the block.
            mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
            with patch(
                "app.services.provider_cache.reload_providers",
                new_callable=AsyncMock,
            ) as mock_reload:
                mock_reload.return_value = {"text": [], "image": [], "tts": []}
                await _load_provider_cache()
                # Check that the coroutine was actually awaited, not merely created.
                mock_reload.assert_awaited_once()