"""Provider capability and routing policy definitions.""" from dataclasses import dataclass from enum import Enum from typing import Literal, Protocol, TypeAlias ProviderType: TypeAlias = Literal["text", "image", "tts", "storybook"] class RoutingStrategy(str, Enum): """How providers should be ordered before failover execution.""" PRIORITY = "priority" COST = "cost" LATENCY = "latency" ROUND_ROBIN = "round_robin" @dataclass(frozen=True) class CapabilityPolicy: """Product-level capability policy for one provider family.""" capability: ProviderType label: str description: str settings_attr: str default_providers: tuple[str, ...] default_strategy: RoutingStrategy = RoutingStrategy.PRIORITY demo_provider: str | None = None class ProviderSettings(Protocol): """Settings fields required by provider policy resolution.""" text_providers: list[str] image_providers: list[str] tts_providers: list[str] storybook_providers: list[str] enable_demo_providers: bool CAPABILITY_POLICIES: dict[ProviderType, CapabilityPolicy] = { "text": CapabilityPolicy( capability="text", label="文本生成", description="生成或润色儿童故事文本。", settings_attr="text_providers", default_providers=("gemini", "openai"), demo_provider="demo", ), "image": CapabilityPolicy( capability="image", label="图片生成", description="生成故事封面或绘本插图。", settings_attr="image_providers", default_providers=("cqtai",), demo_provider="demo", ), "tts": CapabilityPolicy( capability="tts", label="语音合成", description="将故事文本合成为可播放音频。", settings_attr="tts_providers", default_providers=("minimax", "elevenlabs", "edge_tts"), ), "storybook": CapabilityPolicy( capability="storybook", label="绘本结构生成", description="生成多页绘本结构、分镜文本和插图提示词。", settings_attr="storybook_providers", default_providers=("storybook_primary",), demo_provider="demo", ), } DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = { capability: list(policy.default_providers) for capability, policy in CAPABILITY_POLICIES.items() } API_KEY_MAP: dict[str, str] = { # Text "gemini": "text_api_key", "text_primary": "text_api_key", "text_api_key": "text_api_key", "openai": "openai_api_key", "openai_api_key": "openai_api_key", # Image "cqtai": "cqtai_api_key", "cqtai_api_key": "cqtai_api_key", "antigravity": "antigravity_api_key", "antigravity_api_key": "antigravity_api_key", "image_primary": "image_api_key", "image_api_key": "image_api_key", # TTS "minimax": "minimax_api_key", "minimax_api_key": "minimax_api_key", "elevenlabs": "elevenlabs_api_key", "elevenlabs_api_key": "elevenlabs_api_key", "edge_tts": "tts_api_key", "tts_primary": "tts_api_key", "tts_api_key": "tts_api_key", } def get_capability_policy(capability: ProviderType) -> CapabilityPolicy: """Return the product policy for a provider capability.""" return CAPABILITY_POLICIES[capability] def get_provider_names_from_settings( capability: ProviderType, settings: ProviderSettings, ) -> list[str]: """Resolve provider order from settings, falling back to capability defaults.""" policy = get_capability_policy(capability) configured = getattr(settings, policy.settings_attr, None) names = list(configured or policy.default_providers) if ( settings.enable_demo_providers and policy.demo_provider and policy.demo_provider not in names ): names = [policy.demo_provider, *names] return names def list_capability_policies() -> list[dict[str, object]]: """Return a serializable capability policy overview for admin/docs use.""" return [ { "capability": policy.capability, "label": policy.label, "description": policy.description, "settings_attr": policy.settings_attr, "default_providers": list(policy.default_providers), "default_strategy": policy.default_strategy.value, "demo_provider": policy.demo_provider, } for policy in CAPABILITY_POLICIES.values() ]