refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -8,7 +8,7 @@ from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.cost_tracker import cost_tracker
from app.services.provider_router import DEFAULT_PROVIDERS
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
@@ -68,6 +68,12 @@ async def get_env_defaults():
return DEFAULT_PROVIDERS
@router.get("/providers/capabilities")
async def list_provider_capabilities():
    """Return the provider capability tiers and their default routing policies."""
    capability_overview = list_capability_policies()
    return capability_overview
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))

View File

@@ -1,22 +1,20 @@
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
import json
from collections import defaultdict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
from app.core.redis import get_redis
from app.db.admin_models import Provider
from app.services.provider_policy import ProviderType
logger = get_logger(__name__)
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""

View File

@@ -0,0 +1,148 @@
"""Provider capability and routing policy definitions."""
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
    """Ordering applied to candidate providers before failover execution.

    Members are also plain ``str`` values, so they compare equal to (and
    serialize as) their lowercase names.
    """

    # Order by configured priority.
    PRIORITY = "priority"
    # Order by cost.
    COST = "cost"
    # Order by latency.
    LATENCY = "latency"
    # Rotate through the providers in turn.
    ROUND_ROBIN = "round_robin"
@dataclass(frozen=True)
class CapabilityPolicy:
    """Immutable product-level policy describing one provider capability."""

    # Capability key this policy applies to ("text", "image", "tts", "storybook").
    capability: ProviderType
    # Short human-readable name of the capability.
    label: str
    # One-sentence explanation of what the capability produces.
    description: str
    # Name of the settings attribute that holds the configured provider list.
    settings_attr: str
    # Provider names used, in order, when the settings list is empty.
    default_providers: tuple[str, ...]
    # Routing strategy advertised as the default for this capability.
    default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
    # Demo provider that may be prepended when demo providers are enabled;
    # None means the capability has no demo provider.
    demo_provider: str | None = None
class ProviderSettings(Protocol):
    """Structural type listing the settings fields that policy resolution reads."""

    # Configured provider names, one list per capability.
    text_providers: list[str]
    image_providers: list[str]
    tts_providers: list[str]
    storybook_providers: list[str]

    # Whether a capability's demo provider should be put first in the order.
    enable_demo_providers: bool
# Registry of capability policies, keyed by each policy's own capability name
# so that the key and the policy's ``capability`` field cannot drift apart.
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
    policy.capability: policy
    for policy in (
        CapabilityPolicy(
            capability="text",
            label="文本生成",
            description="生成或润色儿童故事文本。",
            settings_attr="text_providers",
            default_providers=("gemini", "openai"),
            demo_provider="demo",
        ),
        CapabilityPolicy(
            capability="image",
            label="图片生成",
            description="生成故事封面或绘本插图。",
            settings_attr="image_providers",
            default_providers=("cqtai",),
            demo_provider="demo",
        ),
        # TTS declares no demo provider, so none is ever prepended for it.
        CapabilityPolicy(
            capability="tts",
            label="语音合成",
            description="将故事文本合成为可播放音频。",
            settings_attr="tts_providers",
            default_providers=("minimax", "elevenlabs", "edge_tts"),
        ),
        CapabilityPolicy(
            capability="storybook",
            label="绘本结构生成",
            description="生成多页绘本结构、分镜文本和插图提示词。",
            settings_attr="storybook_providers",
            default_providers=("storybook_primary",),
            demo_provider="demo",
        ),
    )
}
# Code-level default provider order per capability, derived from the policy
# registry; each value is a fresh list copied from the policy's tuple.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    name: [*CAPABILITY_POLICIES[name].default_providers]
    for name in CAPABILITY_POLICIES
}
# Lookup from a provider reference to the name of the settings attribute that
# holds its API key. Each section lists adapter names / legacy aliases next to
# the settings field names themselves, which map to themselves.
# NOTE(review): the self-mapping entries presumably let a provider's
# config_ref name a settings field directly -- confirm against the router.
API_KEY_MAP: dict[str, str] = {
    # --- Text: "gemini" and the "text_primary" alias share text_api_key ---
    "gemini": "text_api_key",
    "text_primary": "text_api_key",
    "text_api_key": "text_api_key",
    "openai": "openai_api_key",
    "openai_api_key": "openai_api_key",
    # --- Image ---
    "cqtai": "cqtai_api_key",
    "cqtai_api_key": "cqtai_api_key",
    "antigravity": "antigravity_api_key",
    "antigravity_api_key": "antigravity_api_key",
    "image_primary": "image_api_key",
    "image_api_key": "image_api_key",
    # --- TTS: "edge_tts" and the "tts_primary" alias share tts_api_key ---
    "minimax": "minimax_api_key",
    "minimax_api_key": "minimax_api_key",
    "elevenlabs": "elevenlabs_api_key",
    "elevenlabs_api_key": "elevenlabs_api_key",
    "edge_tts": "tts_api_key",
    "tts_primary": "tts_api_key",
    "tts_api_key": "tts_api_key",
}
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
    """Look up the product policy registered for ``capability``.

    Raises:
        KeyError: If no policy is registered under ``capability``.
    """
    policy = CAPABILITY_POLICIES[capability]
    return policy
def get_provider_names_from_settings(
    capability: ProviderType,
    settings: ProviderSettings,
) -> list[str]:
    """Work out the ordered provider names to try for ``capability``.

    The list configured in ``settings`` wins; when it is missing or empty the
    capability's default providers are used. If demo providers are enabled
    and the capability declares one that is not already listed, it is placed
    first.

    Args:
        capability: Capability whose provider order is wanted.
        settings: Object exposing the per-capability provider lists and the
            ``enable_demo_providers`` flag.

    Returns:
        A new list of provider names in the order they should be tried.
    """
    policy = get_capability_policy(capability)
    from_settings = getattr(settings, policy.settings_attr, None)
    if from_settings:
        ordered = list(from_settings)
    else:
        ordered = list(policy.default_providers)

    demo = policy.demo_provider
    # Nothing to add when demos are off, undefined, or already present.
    if not settings.enable_demo_providers or not demo or demo in ordered:
        return ordered
    return [demo, *ordered]
def list_capability_policies() -> list[dict[str, object]]:
    """Build a plain-data overview of every capability policy.

    Returns:
        One dict per registered policy, in registry order. Tuples are turned
        into lists and the strategy enum into its string value so that the
        result contains only basic serializable types.
    """
    overview: list[dict[str, object]] = []
    for policy in CAPABILITY_POLICIES.values():
        entry: dict[str, object] = {
            "capability": policy.capability,
            "label": policy.label,
            "description": policy.description,
            "settings_attr": policy.settings_attr,
            "default_providers": list(policy.default_providers),
            "default_strategy": policy.default_strategy.value,
            "demo_provider": policy.demo_provider,
        }
        overview.append(entry)
    return overview

View File

@@ -1,7 +1,6 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
API_KEY_MAP,
DEFAULT_PROVIDERS,
ProviderType,
RoutingStrategy,
get_provider_names_from_settings,
)
if TYPE_CHECKING:
from app.db.admin_models import Provider
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["storybook_primary"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
# 延迟缓存(内存中,简化实现)
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
model=settings.image_model or "nano-banana-pro",
timeout_ms=120000,
)
if adapter_name == "antigravity":
return AdapterConfig(
api_key=getattr(settings, "antigravity_api_key", ""),
api_base=getattr(settings, "antigravity_api_base", ""),
model=settings.antigravity_model,
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
"storybook": settings.storybook_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
if settings.enable_demo_providers and "demo" not in names:
names = ["demo", *names]
names = get_provider_names_from_settings(provider_type, settings)
result = []
for name in names:

View File

@@ -1,14 +1,21 @@
"""Provider router 测试 - failover 和配置加载。"""
from unittest.mock import AsyncMock, MagicMock, patch
"""Provider router 测试 - failover 和配置加载。"""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.adapters import AdapterConfig
from app.services.adapters.text.models import StoryOutput
from app.services.adapters.text.models import StoryOutput
from app.services.provider_policy import (
DEFAULT_PROVIDERS,
RoutingStrategy,
get_provider_names_from_settings,
list_capability_policies,
)
class TestProviderFailover:
class TestProviderFailover:
"""Provider failover 测试。"""
@pytest.mark.asyncio
@@ -126,7 +133,7 @@ class TestProviderFailover:
)
@pytest.mark.asyncio
async def test_default_provider_skips_fk_backed_metrics(self):
async def test_default_provider_skips_fk_backed_metrics(self):
"""环境变量/default provider 没有 providers 表记录,不写带外键的指标表。"""
from app.services import provider_router
@@ -187,7 +194,91 @@ class TestProviderFailover:
mock_record_call.assert_not_called()
mock_record_call_result.assert_not_called()
mock_record_cost.assert_awaited_once()
assert mock_record_cost.await_args.kwargs["provider_id"] is None
assert mock_record_cost.await_args.kwargs["provider_id"] is None
@pytest.mark.asyncio
async def test_storybook_round_robin_strategy_is_supported(self):
    """Storybook generation must work with the round-robin routing strategy.

    Every capability should be usable with a routing policy; storybook must
    not be left out of the round-robin counters.
    """
    from app.services import provider_router
    from app.services.adapters.storybook.primary import Storybook

    expected_storybook = Storybook(
        title="轮询绘本",
        main_character="小星",
        art_style="温暖水彩",
        pages=[],
        cover_prompt="cover",
    )

    class MockAdapter:
        """Adapter stand-in that always yields the expected storybook."""

        estimated_cost = 0.0

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return expected_storybook

    providers_patch = patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
    )
    registry_patch = patch.object(
        provider_router.AdapterRegistry,
        "get",
        return_value=MockAdapter,
    )

    with providers_patch as mock_providers, registry_patch:
        # A single default provider with no DB record behind it.
        mock_providers.return_value = [
            ("storybook_primary", AdapterConfig(api_key=""), None),
        ]
        result = await provider_router.generate_storybook(
            keywords="测试",
            strategy=RoutingStrategy.ROUND_ROBIN,
        )

    assert result == expected_storybook
class TestProviderPolicy:
    """Boundary tests for the provider capability / routing policy helpers."""

    def test_policy_lists_all_capabilities(self):
        """The overview names all four capabilities; storybook keeps its default."""
        listed = {entry["capability"] for entry in list_capability_policies()}

        assert listed == {"text", "image", "tts", "storybook"}
        assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]

    def test_demo_provider_only_added_to_supported_capabilities(self):
        """The demo provider is prepended only where a capability declares one."""
        demo_enabled = SimpleNamespace(
            text_providers=["gemini"],
            image_providers=["cqtai"],
            tts_providers=["edge_tts"],
            storybook_providers=["storybook_primary"],
            enable_demo_providers=True,
        )
        # tts comes last and expects no "demo" entry.
        expected_by_capability = {
            "text": ["demo", "gemini"],
            "image": ["demo", "cqtai"],
            "storybook": ["demo", "storybook_primary"],
            "tts": ["edge_tts"],
        }

        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, demo_enabled) == expected

    def test_policy_defaults_when_settings_lists_are_empty(self):
        """Empty settings lists fall back to the capability defaults."""
        nothing_configured = SimpleNamespace(
            text_providers=[],
            image_providers=[],
            tts_providers=[],
            storybook_providers=[],
            enable_demo_providers=False,
        )
        expected_by_capability = {
            "text": ["gemini", "openai"],
            "tts": ["minimax", "elevenlabs", "edge_tts"],
        }

        for capability, expected in expected_by_capability.items():
            assert (
                get_provider_names_from_settings(capability, nothing_configured)
                == expected
            )
class TestProviderConfigFromDB:
@@ -215,10 +306,10 @@ class TestProviderConfigFromDB:
assert config.timeout_ms == 30000
assert config.max_retries == 5
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
def test_build_config_fallback_to_settings(self):
"""Provider 无 api_key 时回退到 settings。"""
from app.services.provider_router import _build_config_from_provider
mock_provider = MagicMock()
mock_provider.adapter = "text_primary"
mock_provider.api_key = None
@@ -234,8 +325,33 @@ class TestProviderConfigFromDB:
mock_settings.text_model = "gemini-2.0-flash"
config = _build_config_from_provider(mock_provider)
assert config.api_key == "settings-api-key"
assert config.api_key == "settings-api-key"
def test_build_config_uses_direct_config_ref_name(self):
    """A config_ref may name a settings field directly, easing admin-side setup."""
    from app.services.provider_router import _build_config_from_provider

    # Provider row with every override unset, so values must come from settings.
    provider = MagicMock(
        adapter="antigravity",
        api_key=None,
        api_base=None,
        model=None,
        timeout_ms=None,
        max_retries=None,
        config_ref="antigravity_api_key",
        config_json={},
    )

    with patch("app.services.provider_router.settings") as mock_settings:
        mock_settings.configure_mock(
            antigravity_api_key="antigravity-key",
            antigravity_api_base="https://antigravity.example",
            antigravity_model="gemini-3-pro-image",
        )
        config = _build_config_from_provider(provider)

    assert config.api_key == "antigravity-key"
    assert config.api_base == "https://antigravity.example"
    assert config.model == "gemini-3-pro-image"
class TestProviderCacheStartup: