refactor: separate provider capability policy
This commit is contained in:
@@ -8,7 +8,7 @@ from app.db.admin_models import Provider
|
||||
from app.db.database import get_db
|
||||
from app.services.adapters.registry import AdapterRegistry
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.provider_router import DEFAULT_PROVIDERS
|
||||
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
|
||||
from app.services.secret_service import SecretService
|
||||
|
||||
router = APIRouter(dependencies=[Depends(admin_guard)])
|
||||
@@ -68,6 +68,12 @@ async def get_env_defaults():
|
||||
return DEFAULT_PROVIDERS
|
||||
|
||||
|
||||
@router.get("/providers/capabilities")
async def list_provider_capabilities():
    """Return the provider capability tiers and their default routing policies."""
    # The policy module already produces a JSON-serializable overview.
    overview = list_capability_policies()
    return overview
|
||||
|
||||
|
||||
@router.get("/providers", response_model=list[ProviderResponse])
|
||||
async def list_providers(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Provider))
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
"""Redis-backed cache for providers loaded from DB."""
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.core.redis import get_redis
|
||||
from app.db.admin_models import Provider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
from app.core.redis import get_redis
|
||||
from app.db.admin_models import Provider
|
||||
from app.services.provider_policy import ProviderType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CachedProvider(BaseModel):
|
||||
"""Serializable provider configuration matching DB model fields."""
|
||||
|
||||
|
||||
148
backend/app/services/provider_policy.py
Normal file
148
backend/app/services/provider_policy.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Provider capability and routing policy definitions."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Protocol, TypeAlias
|
||||
|
||||
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
    """Ordering applied to candidate providers before failover execution.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    # Order by priority.
    PRIORITY = "priority"
    # Order by cost.
    COST = "cost"
    # Order by latency.
    LATENCY = "latency"
    # Rotate through the providers in turn.
    ROUND_ROBIN = "round_robin"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CapabilityPolicy:
    """Immutable product-level policy for a single provider family.

    Instances are frozen, so a policy cannot be modified after construction.
    """

    # Capability key this policy applies to ("text", "image", "tts", "storybook").
    capability: ProviderType
    # Short display name for the capability.
    label: str
    # Longer explanation of what the capability does.
    description: str
    # Name of the settings attribute holding the configured provider list.
    settings_attr: str
    # Provider names used when the settings attribute is empty or missing.
    default_providers: tuple[str, ...]
    # Routing strategy applied unless a caller chooses another one.
    default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
    # Demo provider for this capability, or None when there is none.
    demo_provider: str | None = None
|
||||
|
||||
|
||||
class ProviderSettings(Protocol):
    """Structural type for the settings fields read during policy resolution.

    Any object exposing these attributes satisfies the protocol; it does not
    need to inherit from this class.
    """

    # Configured provider names, one list per capability.
    text_providers: list[str]
    image_providers: list[str]
    tts_providers: list[str]
    storybook_providers: list[str]

    # When true, a capability's demo provider (if it has one) is put first.
    enable_demo_providers: bool
|
||||
|
||||
|
||||
# Registry of capability policies, keyed by each policy's own ``capability``
# value so the key and the policy can never disagree.
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
    entry.capability: entry
    for entry in (
        CapabilityPolicy(
            capability="text",
            label="文本生成",
            description="生成或润色儿童故事文本。",
            settings_attr="text_providers",
            default_providers=("gemini", "openai"),
            demo_provider="demo",
        ),
        CapabilityPolicy(
            capability="image",
            label="图片生成",
            description="生成故事封面或绘本插图。",
            settings_attr="image_providers",
            default_providers=("cqtai",),
            demo_provider="demo",
        ),
        # TTS declares no demo provider, so none is ever injected for it.
        CapabilityPolicy(
            capability="tts",
            label="语音合成",
            description="将故事文本合成为可播放音频。",
            settings_attr="tts_providers",
            default_providers=("minimax", "elevenlabs", "edge_tts"),
        ),
        CapabilityPolicy(
            capability="storybook",
            label="绘本结构生成",
            description="生成多页绘本结构、分镜文本和插图提示词。",
            settings_attr="storybook_providers",
            default_providers=("storybook_primary",),
            demo_provider="demo",
        ),
    )
}
|
||||
|
||||
|
||||
# Default provider order per capability, derived from the policy registry.
# Each value is a fresh list copied from the policy's immutable tuple.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    key: [*entry.default_providers] for key, entry in CAPABILITY_POLICIES.items()
}
|
||||
|
||||
|
||||
# Maps a lookup name to the settings attribute that stores its API key.
# Besides adapter names and legacy ``*_primary`` aliases, every settings
# attribute name maps to itself.
API_KEY_MAP: dict[str, str] = (
    # Text
    {
        "gemini": "text_api_key",
        "text_primary": "text_api_key",
        "text_api_key": "text_api_key",
        "openai": "openai_api_key",
        "openai_api_key": "openai_api_key",
    }
    # Image
    | {
        "cqtai": "cqtai_api_key",
        "cqtai_api_key": "cqtai_api_key",
        "antigravity": "antigravity_api_key",
        "antigravity_api_key": "antigravity_api_key",
        "image_primary": "image_api_key",
        "image_api_key": "image_api_key",
    }
    # TTS
    | {
        "minimax": "minimax_api_key",
        "minimax_api_key": "minimax_api_key",
        "elevenlabs": "elevenlabs_api_key",
        "elevenlabs_api_key": "elevenlabs_api_key",
        "edge_tts": "tts_api_key",
        "tts_primary": "tts_api_key",
        "tts_api_key": "tts_api_key",
    }
)
|
||||
|
||||
|
||||
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
    """Look up the product policy registered for ``capability``.

    Raises:
        KeyError: If no policy is registered under ``capability``.
    """
    policy = CAPABILITY_POLICIES[capability]
    return policy
|
||||
|
||||
|
||||
def get_provider_names_from_settings(
    capability: ProviderType,
    settings: ProviderSettings,
) -> list[str]:
    """Resolve the ordered provider names for ``capability`` from ``settings``.

    The list configured on ``settings`` is used when it is non-empty;
    otherwise the capability's default providers apply. When demo providers
    are enabled and the capability declares one that is not already listed,
    it is placed first.

    Returns:
        A new list of provider names.
    """
    policy = get_capability_policy(capability)

    # A missing attribute and an empty list both fall back to the defaults.
    configured = getattr(settings, policy.settings_attr, None)
    if configured:
        ordered = list(configured)
    else:
        ordered = list(policy.default_providers)

    if not settings.enable_demo_providers:
        return ordered

    demo = policy.demo_provider
    if not demo or demo in ordered:
        return ordered

    return [demo, *ordered]
|
||||
|
||||
|
||||
def list_capability_policies() -> list[dict[str, object]]:
    """Build a serializable overview of every capability policy.

    Returns:
        One plain dict per policy, in registry order. The default providers
        are returned as a list and the strategy as its string value.
    """

    def _as_row(entry: CapabilityPolicy) -> dict[str, object]:
        # Flatten one policy into plain, serializable values.
        return {
            "capability": entry.capability,
            "label": entry.label,
            "description": entry.description,
            "settings_attr": entry.settings_attr,
            "default_providers": list(entry.default_providers),
            "default_strategy": entry.default_strategy.value,
            "demo_provider": entry.demo_provider,
        }

    return [_as_row(entry) for entry in CAPABILITY_POLICIES.values()]
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.provider_cache import get_providers
|
||||
from app.services.provider_metrics import health_checker, metrics_collector
|
||||
from app.services.provider_policy import (
|
||||
API_KEY_MAP,
|
||||
DEFAULT_PROVIDERS,
|
||||
ProviderType,
|
||||
RoutingStrategy,
|
||||
get_provider_names_from_settings,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.admin_models import Provider
|
||||
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
|
||||
"""路由策略枚举。"""
|
||||
|
||||
PRIORITY = "priority" # 按优先级排序(默认)
|
||||
COST = "cost" # 按成本排序
|
||||
LATENCY = "latency" # 按延迟排序
|
||||
ROUND_ROBIN = "round_robin" # 轮询
|
||||
|
||||
|
||||
# 默认配置映射(当 DB 无配置时使用)
|
||||
# 这是“代码级”的默认策略,对应 .env 为空的情况
|
||||
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
|
||||
"text": ["gemini", "openai"],
|
||||
"image": ["cqtai"],
|
||||
"tts": ["minimax", "elevenlabs", "edge_tts"],
|
||||
"storybook": ["storybook_primary"],
|
||||
}
|
||||
|
||||
# API Key 映射:adapter_name -> settings 属性名
|
||||
API_KEY_MAP: dict[str, str] = {
|
||||
# Text
|
||||
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
|
||||
"text_primary": "text_api_key", # 兼容旧别名
|
||||
"openai": "openai_api_key",
|
||||
|
||||
# Image
|
||||
"cqtai": "cqtai_api_key",
|
||||
"image_primary": "image_api_key", # 兼容旧别名
|
||||
|
||||
# TTS
|
||||
"minimax": "minimax_api_key",
|
||||
"elevenlabs": "elevenlabs_api_key",
|
||||
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
|
||||
"tts_primary": "tts_api_key", # 兼容旧别名
|
||||
}
|
||||
|
||||
# 轮询计数器
|
||||
_round_robin_counters: dict[ProviderType, int] = {
|
||||
"text": 0,
|
||||
"image": 0,
|
||||
"tts": 0,
|
||||
provider_type: 0 for provider_type in DEFAULT_PROVIDERS
|
||||
}
|
||||
|
||||
# 延迟缓存(内存中,简化实现)
|
||||
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
model=settings.image_model or "nano-banana-pro",
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "antigravity":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "antigravity_api_key", ""),
|
||||
api_base=getattr(settings, "antigravity_api_base", ""),
|
||||
model=settings.antigravity_model,
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "image_primary":
|
||||
# 如果还有地方在用 image_primary,暂时映射到快或者其他
|
||||
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
|
||||
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
|
||||
if db_providers:
|
||||
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
||||
|
||||
settings_map = {
|
||||
"text": settings.text_providers,
|
||||
"image": settings.image_providers,
|
||||
"tts": settings.tts_providers,
|
||||
"storybook": settings.storybook_providers,
|
||||
}
|
||||
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
|
||||
if settings.enable_demo_providers and "demo" not in names:
|
||||
names = ["demo", *names]
|
||||
names = get_provider_names_from_settings(provider_type, settings)
|
||||
|
||||
result = []
|
||||
for name in names:
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
"""Provider router 测试 - failover 和配置加载。"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.adapters import AdapterConfig
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.provider_policy import (
|
||||
DEFAULT_PROVIDERS,
|
||||
RoutingStrategy,
|
||||
get_provider_names_from_settings,
|
||||
list_capability_policies,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderFailover:
|
||||
class TestProviderFailover:
|
||||
"""Provider failover 测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -126,7 +133,7 @@ class TestProviderFailover:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_provider_skips_fk_backed_metrics(self):
|
||||
async def test_default_provider_skips_fk_backed_metrics(self):
|
||||
"""环境变量/default provider 没有 providers 表记录,不写带外键的指标表。"""
|
||||
from app.services import provider_router
|
||||
|
||||
@@ -187,7 +194,91 @@ class TestProviderFailover:
|
||||
mock_record_call.assert_not_called()
|
||||
mock_record_call_result.assert_not_called()
|
||||
mock_record_cost.assert_awaited_once()
|
||||
assert mock_record_cost.await_args.kwargs["provider_id"] is None
|
||||
assert mock_record_cost.await_args.kwargs["provider_id"] is None
|
||||
|
||||
@pytest.mark.asyncio
async def test_storybook_round_robin_strategy_is_supported(self):
    """Every capability must accept a routing policy; storybook must not lack a round-robin counter."""
    from app.services import provider_router
    from app.services.adapters.storybook.primary import Storybook

    mock_storybook = Storybook(
        title="轮询绘本",
        main_character="小星",
        art_style="温暖水彩",
        pages=[],
        cover_prompt="cover",
    )

    class MockAdapter:
        # Minimal adapter stand-in that always returns the prepared storybook.
        estimated_cost = 0.0

        def __init__(self, config):
            self.config = config

        async def execute(self, **kwargs):
            return mock_storybook

    # A single default provider with no DB record (third element is None).
    candidates = [
        ("storybook_primary", AdapterConfig(api_key=""), None),
    ]

    with patch.object(
        provider_router,
        "_get_providers_with_config",
        new_callable=AsyncMock,
        return_value=candidates,
    ):
        with patch.object(provider_router.AdapterRegistry, "get", return_value=MockAdapter):
            result = await provider_router.generate_storybook(
                keywords="测试",
                strategy=RoutingStrategy.ROUND_ROBIN,
            )

    assert result == mock_storybook
|
||||
|
||||
|
||||
class TestProviderPolicy:
    """Boundary tests for provider capability and routing policy resolution."""

    def test_policy_lists_all_capabilities(self):
        """The overview covers all four capabilities, storybook included."""
        listed = {entry["capability"] for entry in list_capability_policies()}

        assert listed == {"text", "image", "tts", "storybook"}
        assert DEFAULT_PROVIDERS["storybook"] == ["storybook_primary"]

    def test_demo_provider_only_added_to_supported_capabilities(self):
        """The demo provider is prepended only where the policy declares one."""
        settings = SimpleNamespace(
            text_providers=["gemini"],
            image_providers=["cqtai"],
            tts_providers=["edge_tts"],
            storybook_providers=["storybook_primary"],
            enable_demo_providers=True,
        )

        # TTS has no demo provider, so its list must stay untouched.
        expected_by_capability = {
            "text": ["demo", "gemini"],
            "image": ["demo", "cqtai"],
            "storybook": ["demo", "storybook_primary"],
            "tts": ["edge_tts"],
        }
        for capability, expected in expected_by_capability.items():
            assert get_provider_names_from_settings(capability, settings) == expected

    def test_policy_defaults_when_settings_lists_are_empty(self):
        """Empty settings lists fall back to the capability defaults."""
        settings = SimpleNamespace(
            text_providers=[],
            image_providers=[],
            tts_providers=[],
            storybook_providers=[],
            enable_demo_providers=False,
        )

        text_names = get_provider_names_from_settings("text", settings)
        assert text_names == ["gemini", "openai"]

        tts_names = get_provider_names_from_settings("tts", settings)
        assert tts_names == ["minimax", "elevenlabs", "edge_tts"]
|
||||
|
||||
|
||||
class TestProviderConfigFromDB:
|
||||
@@ -215,10 +306,10 @@ class TestProviderConfigFromDB:
|
||||
assert config.timeout_ms == 30000
|
||||
assert config.max_retries == 5
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
def test_build_config_fallback_to_settings(self):
|
||||
"""Provider 无 api_key 时回退到 settings。"""
|
||||
from app.services.provider_router import _build_config_from_provider
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.adapter = "text_primary"
|
||||
mock_provider.api_key = None
|
||||
@@ -234,8 +325,33 @@ class TestProviderConfigFromDB:
|
||||
mock_settings.text_model = "gemini-2.0-flash"
|
||||
|
||||
config = _build_config_from_provider(mock_provider)
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
assert config.api_key == "settings-api-key"
|
||||
|
||||
def test_build_config_uses_direct_config_ref_name(self):
    """config_ref may name a settings field directly, which eases admin configuration."""
    from app.services.provider_router import _build_config_from_provider

    # Provider row with no values of its own; only config_ref points at settings.
    mock_provider = MagicMock(
        adapter="antigravity",
        api_key=None,
        api_base=None,
        model=None,
        timeout_ms=None,
        max_retries=None,
        config_ref="antigravity_api_key",
        config_json={},
    )

    with patch("app.services.provider_router.settings") as mock_settings:
        mock_settings.configure_mock(
            antigravity_api_key="antigravity-key",
            antigravity_api_base="https://antigravity.example",
            antigravity_model="gemini-3-pro-image",
        )

        config = _build_config_from_provider(mock_provider)

    assert config.api_key == "antigravity-key"
    assert config.api_base == "https://antigravity.example"
    assert config.model == "gemini-3-pro-image"
|
||||
|
||||
|
||||
class TestProviderCacheStartup:
|
||||
|
||||
Reference in New Issue
Block a user