refactor: separate provider capability policy
This commit is contained in:
148
backend/app/services/provider_policy.py
Normal file
148
backend/app/services/provider_policy.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Provider capability and routing policy definitions."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal, Protocol, TypeAlias
|
||||
|
||||
ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
|
||||
"""How providers should be ordered before failover execution."""
|
||||
|
||||
PRIORITY = "priority"
|
||||
COST = "cost"
|
||||
LATENCY = "latency"
|
||||
ROUND_ROBIN = "round_robin"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CapabilityPolicy:
|
||||
"""Product-level capability policy for one provider family."""
|
||||
|
||||
capability: ProviderType
|
||||
label: str
|
||||
description: str
|
||||
settings_attr: str
|
||||
default_providers: tuple[str, ...]
|
||||
default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY
|
||||
demo_provider: str | None = None
|
||||
|
||||
|
||||
class ProviderSettings(Protocol):
|
||||
"""Settings fields required by provider policy resolution."""
|
||||
|
||||
text_providers: list[str]
|
||||
image_providers: list[str]
|
||||
tts_providers: list[str]
|
||||
storybook_providers: list[str]
|
||||
enable_demo_providers: bool
|
||||
|
||||
|
||||
CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = {
|
||||
"text": CapabilityPolicy(
|
||||
capability="text",
|
||||
label="文本生成",
|
||||
description="生成或润色儿童故事文本。",
|
||||
settings_attr="text_providers",
|
||||
default_providers=("gemini", "openai"),
|
||||
demo_provider="demo",
|
||||
),
|
||||
"image": CapabilityPolicy(
|
||||
capability="image",
|
||||
label="图片生成",
|
||||
description="生成故事封面或绘本插图。",
|
||||
settings_attr="image_providers",
|
||||
default_providers=("cqtai",),
|
||||
demo_provider="demo",
|
||||
),
|
||||
"tts": CapabilityPolicy(
|
||||
capability="tts",
|
||||
label="语音合成",
|
||||
description="将故事文本合成为可播放音频。",
|
||||
settings_attr="tts_providers",
|
||||
default_providers=("minimax", "elevenlabs", "edge_tts"),
|
||||
),
|
||||
"storybook": CapabilityPolicy(
|
||||
capability="storybook",
|
||||
label="绘本结构生成",
|
||||
description="生成多页绘本结构、分镜文本和插图提示词。",
|
||||
settings_attr="storybook_providers",
|
||||
default_providers=("storybook_primary",),
|
||||
demo_provider="demo",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
|
||||
capability: list(policy.default_providers)
|
||||
for capability, policy in CAPABILITY_POLICIES.items()
|
||||
}
|
||||
|
||||
|
||||
API_KEY_MAP: dict[str, str] = {
|
||||
# Text
|
||||
"gemini": "text_api_key",
|
||||
"text_primary": "text_api_key",
|
||||
"text_api_key": "text_api_key",
|
||||
"openai": "openai_api_key",
|
||||
"openai_api_key": "openai_api_key",
|
||||
# Image
|
||||
"cqtai": "cqtai_api_key",
|
||||
"cqtai_api_key": "cqtai_api_key",
|
||||
"antigravity": "antigravity_api_key",
|
||||
"antigravity_api_key": "antigravity_api_key",
|
||||
"image_primary": "image_api_key",
|
||||
"image_api_key": "image_api_key",
|
||||
# TTS
|
||||
"minimax": "minimax_api_key",
|
||||
"minimax_api_key": "minimax_api_key",
|
||||
"elevenlabs": "elevenlabs_api_key",
|
||||
"elevenlabs_api_key": "elevenlabs_api_key",
|
||||
"edge_tts": "tts_api_key",
|
||||
"tts_primary": "tts_api_key",
|
||||
"tts_api_key": "tts_api_key",
|
||||
}
|
||||
|
||||
|
||||
def get_capability_policy(capability: ProviderType) -> CapabilityPolicy:
|
||||
"""Return the product policy for a provider capability."""
|
||||
|
||||
return CAPABILITY_POLICIES[capability]
|
||||
|
||||
|
||||
def get_provider_names_from_settings(
|
||||
capability: ProviderType,
|
||||
settings: ProviderSettings,
|
||||
) -> list[str]:
|
||||
"""Resolve provider order from settings, falling back to capability defaults."""
|
||||
|
||||
policy = get_capability_policy(capability)
|
||||
configured = getattr(settings, policy.settings_attr, None)
|
||||
names = list(configured or policy.default_providers)
|
||||
|
||||
if (
|
||||
settings.enable_demo_providers
|
||||
and policy.demo_provider
|
||||
and policy.demo_provider not in names
|
||||
):
|
||||
names = [policy.demo_provider, *names]
|
||||
|
||||
return names
|
||||
|
||||
|
||||
def list_capability_policies() -> list[dict[str, object]]:
|
||||
"""Return a serializable capability policy overview for admin/docs use."""
|
||||
|
||||
return [
|
||||
{
|
||||
"capability": policy.capability,
|
||||
"label": policy.label,
|
||||
"description": policy.description,
|
||||
"settings_attr": policy.settings_attr,
|
||||
"default_providers": list(policy.default_providers),
|
||||
"default_strategy": policy.default_strategy.value,
|
||||
"demo_provider": policy.demo_provider,
|
||||
}
|
||||
for policy in CAPABILITY_POLICIES.values()
|
||||
]
|
||||
Reference in New Issue
Block a user