refactor: separate provider capability policy

This commit is contained in:
2026-04-18 13:37:59 +08:00
parent 0444b81df6
commit 7b8e7c9944
11 changed files with 393 additions and 88 deletions

View File

@@ -0,0 +1,148 @@
"""Provider capability and routing policy definitions."""
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol, TypeAlias
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""How providers should be ordered before failover execution."""
PRIORITY = "priority"
COST = "cost"
LATENCY = "latency"
ROUND_ROBIN = "round_robin"
@dataclass(frozen=True)
class CapabilityPolicy:
"""Product-level capability policy for one provider family."""
capability: ProviderType
label: str
description: str
settings_attr: str
default_providers: tuple[str, ...]
default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
demo_provider: str | None = None
class ProviderSettings(Protocol):
"""Settings fields required by provider policy resolution."""
text_providers: list[str]
image_providers: list[str]
tts_providers: list[str]
storybook_providers: list[str]
enable_demo_providers: bool
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
"text": CapabilityPolicy(
capability="text",
label="文本生成",
description="生成或润色儿童故事文本。",
settings_attr="text_providers",
default_providers=("gemini", "openai"),
demo_provider="demo",
),
"image": CapabilityPolicy(
capability="image",
label="图片生成",
description="生成故事封面或绘本插图。",
settings_attr="image_providers",
default_providers=("cqtai",),
demo_provider="demo",
),
"tts": CapabilityPolicy(
capability="tts",
label="语音合成",
description="将故事文本合成为可播放音频。",
settings_attr="tts_providers",
default_providers=("minimax", "elevenlabs", "edge_tts"),
),
"storybook": CapabilityPolicy(
capability="storybook",
label="绘本结构生成",
description="生成多页绘本结构、分镜文本和插图提示词。",
settings_attr="storybook_providers",
default_providers=("storybook_primary",),
demo_provider="demo",
),
}
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
capability: list(policy.default_providers)
for capability, policy in CAPABILITY_POLICIES.items()
}
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key",
"text_primary": "text_api_key",
"text_api_key": "text_api_key",
"openai": "openai_api_key",
"openai_api_key": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"cqtai_api_key": "cqtai_api_key",
"antigravity": "antigravity_api_key",
"antigravity_api_key": "antigravity_api_key",
"image_primary": "image_api_key",
"image_api_key": "image_api_key",
# TTS
"minimax": "minimax_api_key",
"minimax_api_key": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"elevenlabs_api_key": "elevenlabs_api_key",
"edge_tts": "tts_api_key",
"tts_primary": "tts_api_key",
"tts_api_key": "tts_api_key",
}
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
"""Return the product policy for a provider capability."""
return CAPABILITY_POLICIES[capability]
def get_provider_names_from_settings(
capability: ProviderType,
settings: ProviderSettings,
) -> list[str]:
"""Resolve provider order from settings, falling back to capability defaults."""
policy = get_capability_policy(capability)
configured = getattr(settings, policy.settings_attr, None)
names = list(configured or policy.default_providers)
if (
settings.enable_demo_providers
and policy.demo_provider
and policy.demo_provider not in names
):
names = [policy.demo_provider, *names]
return names
def list_capability_policies() -> list[dict[str, object]]:
"""Return a serializable capability policy overview for admin/docs use."""
return [
{
"capability": policy.capability,
"label": policy.label,
"description": policy.description,
"settings_attr": policy.settings_attr,
"default_providers": list(policy.default_providers),
"default_strategy": policy.default_strategy.value,
"demo_provider": policy.demo_provider,
}
for policy in CAPABILITY_POLICIES.values()
]