refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -8,7 +8,7 @@ from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.cost_tracker import cost_tracker
from app.services.provider_router import DEFAULT_PROVIDERS
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
router = APIRouter(dependencies=[Depends(admin_guard)])
@@ -68,6 +68,12 @@ async def get_env_defaults():
return DEFAULT_PROVIDERS
@router.get("/providers/capabilities")
async def list_provider_capabilities():
"""获取 Provider 能力分层与默认路由策略。"""
return list_capability_policies()
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Provider))

View File

@@ -1,22 +1,20 @@
"""Redis-backed cache for providers loaded from DB."""
import json
from collections import defaultdict
from typing import Literal
from pydantic import BaseModel
import json
from collections import defaultdict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.core.redis import get_redis
from app.db.admin_models import Provider
logger = get_logger(__name__)
ProviderType = Literal["text", "image", "tts", "storybook"]
from app.core.redis import get_redis
from app.db.admin_models import Provider
from app.services.provider_policy import ProviderType
logger = get_logger(__name__)
class CachedProvider(BaseModel):
"""Serializable provider configuration matching DB model fields."""

View File

@@ -0,0 +1,148 @@
"""Provider capability and routing policy definitions."""
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""How providers should be ordered before failover execution."""
PRIORITY = "priority"
COST = "cost"
LATENCY = "latency"
ROUND_ROBIN = "round_robin"
@dataclass(frozen=True)
class CapabilityPolicy:
"""Product-level capability policy for one provider family."""
capability: ProviderType
label: str
description: str
settings_attr: str
default_providers: tuple[str, ...]
default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
demo_provider: str | None = None
class ProviderSettings(Protocol):
"""Settings fields required by provider policy resolution."""
text_providers: list[str]
image_providers: list[str]
tts_providers: list[str]
storybook_providers: list[str]
enable_demo_providers: bool
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
"text": CapabilityPolicy(
capability="text",
label="文本生成",
description="生成或润色儿童故事文本。",
settings_attr="text_providers",
default_providers=("gemini", "openai"),
demo_provider="demo",
),
"image": CapabilityPolicy(
capability="image",
label="图片生成",
description="生成故事封面或绘本插图。",
settings_attr="image_providers",
default_providers=("cqtai",),
demo_provider="demo",
),
"tts": CapabilityPolicy(
capability="tts",
label="语音合成",
description="将故事文本合成为可播放音频。",
settings_attr="tts_providers",
default_providers=("minimax", "elevenlabs", "edge_tts"),
),
"storybook": CapabilityPolicy(
capability="storybook",
label="绘本结构生成",
description="生成多页绘本结构、分镜文本和插图提示词。",
settings_attr="storybook_providers",
default_providers=("storybook_primary",),
demo_provider="demo",
),
}
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
capability: list(policy.default_providers)
for capability, policy in CAPABILITY_POLICIES.items()
}
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key",
"text_primary": "text_api_key",
"text_api_key": "text_api_key",
"openai": "openai_api_key",
"openai_api_key": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"cqtai_api_key": "cqtai_api_key",
"antigravity": "antigravity_api_key",
"antigravity_api_key": "antigravity_api_key",
"image_primary": "image_api_key",
"image_api_key": "image_api_key",
# TTS
"minimax": "minimax_api_key",
"minimax_api_key": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"elevenlabs_api_key": "elevenlabs_api_key",
"edge_tts": "tts_api_key",
"tts_primary": "tts_api_key",
"tts_api_key": "tts_api_key",
}
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
"""Return the product policy for a provider capability."""
return CAPABILITY_POLICIES[capability]
def get_provider_names_from_settings(
capability: ProviderType,
settings: ProviderSettings,
) -> list[str]:
"""Resolve provider order from settings, falling back to capability defaults."""
policy = get_capability_policy(capability)
configured = getattr(settings, policy.settings_attr, None)
names = list(configured or policy.default_providers)
if (
settings.enable_demo_providers
and policy.demo_provider
and policy.demo_provider not in names
):
names = [policy.demo_provider, *names]
return names
def list_capability_policies() -> list[dict[str, object]]:
"""Return a serializable capability policy overview for admin/docs use."""
return [
{
"capability": policy.capability,
"label": policy.label,
"description": policy.description,
"settings_attr": policy.settings_attr,
"default_providers": list(policy.default_providers),
"default_strategy": policy.default_strategy.value,
"demo_provider": policy.demo_provider,
}
for policy in CAPABILITY_POLICIES.values()
]

View File

@@ -1,7 +1,6 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
API_KEY_MAP,
DEFAULT_PROVIDERS,
ProviderType,
RoutingStrategy,
get_provider_names_from_settings,
)
if TYPE_CHECKING:
from app.db.admin_models import Provider
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["storybook_primary"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
# 延迟缓存(内存中,简化实现)
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
model=settings.image_model or "nano-banana-pro",
timeout_ms=120000,
)
if adapter_name == "antigravity":
return AdapterConfig(
api_key=getattr(settings, "antigravity_api_key", ""),
api_base=getattr(settings, "antigravity_api_base", ""),
model=settings.antigravity_model,
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
"storybook": settings.storybook_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
if settings.enable_demo_providers and "demo" not in names:
names = ["demo", *names]
names = get_provider_names_from_settings(provider_type, settings)
result = []
for name in names: