refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -1,14 +1,21 @@
"""Provider router 测试 - failover 和配置加载。"""
from unittest.mock import AsyncMock, MagicMock, patch
"""Provider router 测试 - failover 和配置加载。"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
from app.services.adapters.text.models import StoryOutput
from app.services.provider_policy import (
DEFAULT_PROVIDERS,
RoutingStrategy,
get_provider_names_from_settings,
list_capability_policies,
)
class TestProviderFailover:
class TestProviderFailover:
"""Provider failover 测试。"""
@pytest.mark.asyncio
@@ -126,7 +133,7 @@ class TestProviderFailover:
)
@pytest.mark.asyncio
async def test_default_provider_skips_fk_backed_metrics(self):
async def test_default_provider_skips_fk_backed_metrics(self):
"""环境变量/default provider 没有 providers 表记录,不写带外键的指标表。"""
from app.services import provider_router
@@ -187,7 +194,91 @@ class TestProviderFailover:
mock_record_call.assert_not_called()
mock_record_call_result.assert_not_called()
mock_record_cost.assert_awaited_once()
assert mock_record_cost.await_args.kwargs["provider_id"] is None
assert mock_record_cost.await_args.kwargs["provider_id"] is None
@pytest.mark.asyncio
async def test_storybook_round_robin_strategy_is_supported(self):
"""所有能力都应能使用 routing policystorybook 不能漏掉轮询计数器。"""
from app.services import provider_router
from app.services.adapters.storybook.primary import Storybook
mock_storybook = Storybook(
title="轮询绘本",
main_character="小星",
art_style="温暖水彩",
pages=[],
cover_prompt="cover",
)
class MockAdapter:
estimated_cost = 0.0
def __init__(self, config):
self.config = config
async def execute(self, **kwargs):
return mock_storybook
with patch.object(
provider_router,
"_get_providers_with_config",
new_callable=AsyncMock,
) as mock_providers:
mock_providers.return_value = [
("storybook_primary", AdapterConfig(api_key=""), None),
]
with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
result = await provider_router.generate_storybook(
keywords="测试",
strategy=RoutingStrategy.ROUND_ROBIN,
)
assert result == mock_storybook
class TestProviderPolicy:
"""Provider capability / routing policy boundary tests."""
def test_policy_lists_all_capabilities(self):
policies = list_capability_policies()
capabilities = {item["capability"] for item in policies}
assert capabilities == {"text", "image", "tts", "storybook"}
assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]
def test_demo_provider_only_added_to_supported_capabilities(self):
settings = SimpleNamespace(
text_providers=["gemini"],
image_providers=["cqtai"],
tts_providers=["edge_tts"],
storybook_providers=["storybook_primary"],
enable_demo_providers=True,
)
assert get_provider_names_from_settings("text", settings) == ["demo", "gemini"]
assert get_provider_names_from_settings("image", settings) == ["demo", "cqtai"]
assert get_provider_names_from_settings("storybook", settings) == [
"demo",
"storybook_primary",
]
assert get_provider_names_from_settings("tts", settings) == ["edge_tts"]
def test_policy_defaults_when_settings_lists_are_empty(self):
settings = SimpleNamespace(
text_providers=[],
image_providers=[],
tts_providers=[],
storybook_providers=[],
enable_demo_providers=False,
)
assert get_provider_names_from_settings("text", settings) == ["gemini", "openai"]
assert get_provider_names_from_settings("tts", settings) == [
"minimax",
"elevenlabs",
"edge_tts",
]
class TestProviderConfigFromDB:
@@ -215,10 +306,10 @@ class TestProviderConfigFromDB:
assert config.timeout_ms == 30000
assert config.max_retries == 5
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "text_primary"
mock_provider.api_key = None
@@ -234,8 +325,33 @@ class TestProviderConfigFromDB:
mock_settings.text_model = "gemini-2.0-flash"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "settings-api-key"
assert config.api_key == "settings-api-key"
def test_build_config_uses_direct_config_ref_name(self):
"""config_ref 可以直接使用 settings 字段名,便于后台配置。"""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "antigravity"
mock_provider.api_key = None
mock_provider.api_base = None
mock_provider.model = None
mock_provider.timeout_ms = None
mock_provider.max_retries = None
mock_provider.config_ref = "antigravity_api_key"
mock_provider.config_json = {}
with patch("app.services.provider_router.settings") as mock_settings:
mock_settings.antigravity_api_key = "antigravity-key"
mock_settings.antigravity_api_base = "https://antigravity.example"
mock_settings.antigravity_model = "gemini-3-pro-image"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "antigravity-key"
assert config.api_base == "https://antigravity.example"
assert config.model == "gemini-3-pro-image"
class TestProviderCacheStartup: