refactor: separate provider capability policy
This commit is contained in:
@@ -1,14 +1,21 @@
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.adapters import AdapterConfig
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.provider_policy import (
|
||||
DEFAULT_PROVIDERS,
|
||||
RoutingStrategy,
|
||||
get_provider_names_from_settings,
|
||||
list_capability_policies,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderFailover:
|
||||
class TestProviderFailover:
|
||||
"""Provider failover 测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -126,7 +133,7 @@ class TestProviderFailover:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_provider_skips_fk_backed_metrics(self):
|
||||
async def test_default_provider_skips_fk_backed_metrics(self):
|
||||
"""环境变量/default provider 没有 providers 表记录,不写带外键的指标表。"""
|
||||
from app.services import provider_router
|
||||
|
||||
@@ -187,7 +194,91 @@ class TestProviderFailover:
|
||||
mock_record_call.assert_not_called()
|
||||
mock_record_call_result.assert_not_called()
|
||||
mock_record_cost.assert_awaited_once()
|
||||
assert mock_record_cost.await_args.kwargs["provider_id"] is None
|
||||
assert mock_record_cost.await_args.kwargs["provider_id"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_storybook_round_robin_strategy_is_supported(self):
|
||||
"""所有能力都应能使用 routing policy,storybook 不能漏掉轮询计数器。"""
|
||||
from app.services import provider_router
|
||||
from app.services.adapters.storybook.primary import Storybook
|
||||
|
||||
mock_storybook = Storybook(
|
||||
title="轮询绘本",
|
||||
main_character="小星",
|
||||
art_style="温暖水彩",
|
||||
pages=[],
|
||||
cover_prompt="cover",
|
||||
)
|
||||
|
||||
class MockAdapter:
|
||||
estimated_cost = 0.0
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
return mock_storybook
|
||||
|
||||
with patch.object(
|
||||
provider_router,
|
||||
"_get_providers_with_config",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_providers:
|
||||
mock_providers.return_value = [
|
||||
("storybook_primary", AdapterConfig(api_key=""), None),
|
||||
]
|
||||
|
||||
with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
|
||||
result = await provider_router.generate_storybook(
|
||||
keywords="测试",
|
||||
strategy=RoutingStrategy.ROUND_ROBIN,
|
||||
)
|
||||
|
||||
assert result == mock_storybook
|
||||
|
||||
|
||||
class TestProviderPolicy:
|
||||
"""Provider capability / routing policy boundary tests."""
|
||||
|
||||
def test_policy_lists_all_capabilities(self):
|
||||
policies = list_capability_policies()
|
||||
capabilities = {item["capability"] for item in policies}
|
||||
|
||||
assert capabilities == {"text", "image", "tts", "storybook"}
|
||||
assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]
|
||||
|
||||
def test_demo_provider_only_added_to_supported_capabilities(self):
|
||||
settings = SimpleNamespace(
|
||||
text_providers=["gemini"],
|
||||
image_providers=["cqtai"],
|
||||
tts_providers=["edge_tts"],
|
||||
storybook_providers=["storybook_primary"],
|
||||
enable_demo_providers=True,
|
||||
)
|
||||
|
||||
assert get_provider_names_from_settings("text", settings) == ["demo", "gemini"]
|
||||
assert get_provider_names_from_settings("image", settings) == ["demo", "cqtai"]
|
||||
assert get_provider_names_from_settings("storybook", settings) == [
|
||||
"demo",
|
||||
"storybook_primary",
|
||||
]
|
||||
assert get_provider_names_from_settings("tts", settings) == ["edge_tts"]
|
||||
|
||||
def test_policy_defaults_when_settings_lists_are_empty(self):
|
||||
settings = SimpleNamespace(
|
||||
text_providers=[],
|
||||
image_providers=[],
|
||||
tts_providers=[],
|
||||
storybook_providers=[],
|
||||
enable_demo_providers=False,
|
||||
)
|
||||
|
||||
assert get_provider_names_from_settings("text", settings) == ["gemini", "openai"]
|
||||
assert get_provider_names_from_settings("tts", settings) == [
|
||||
"minimax",
|
||||
"elevenlabs",
|
||||
"edge_tts",
|
||||
]
|
||||
|
||||
|
||||
class TestProviderConfigFromDB:
|
||||
@@ -215,10 +306,10 @@ class TestProviderConfigFromDB:
|
||||
assert config.timeout_ms == 30000
|
||||
assert config.max_retries == 5
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = None
|
||||
@@ -234,8 +325,33 @@ class TestProviderConfigFromDB:
|
||||
mock_settings.text_model = "gemini-2.0-flash"
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
def test_build_config_uses_direct_config_ref_name(self):
|
||||
"""config_ref 可以直接使用 settings 字段名,便于后台配置。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "antigravity"
|
||||
mock_provider.api_key = None
|
||||
mock_provider.api_base = None
|
||||
mock_provider.model = None
|
||||
mock_provider.timeout_ms = None
|
||||
mock_provider.max_retries = None
|
||||
mock_provider.config_ref = "antigravity_api_key"
|
||||
mock_provider.config_json = {}
|
||||
|
||||
with patch("app.services.provider_router.settings") as mock_settings:
|
||||
mock_settings.antigravity_api_key = "antigravity-key"
|
||||
mock_settings.antigravity_api_base = "https://antigravity.example"
|
||||
mock_settings.antigravity_model = "gemini-3-pro-image"
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "antigravity-key"
|
||||
assert config.api_base == "https://antigravity.example"
|
||||
assert config.model == "gemini-3-pro-image"
|
||||
|
||||
|
||||
class TestProviderCacheStartup:
|
||||
|
||||
Reference in New Issue
Block a user