"""Provider routing with failover, driven by the adapter registry.

Candidate providers for a capability ("text", "image", "tts", "storybook")
come from the DB via ``get_providers``. When the DB has no rows, they fall back
to the ``settings.*_providers`` lists and then to ``DEFAULT_PROVIDERS``.
Candidates are ordered by a ``RoutingStrategy`` and tried in turn through the
adapter class registered in ``AdapterRegistry``. The first success wins.
"""

import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector

if TYPE_CHECKING:
    from app.db.admin_models import Provider

logger = get_logger(__name__)

T = TypeVar("T")

ProviderType = Literal["text", "image", "tts", "storybook"]


class RoutingStrategy(str, Enum):
    """Strategy used to order candidate providers before failover."""

    PRIORITY = "priority"  # DB priority desc, then weight desc (default)
    COST = "cost"  # adapter-estimated cost, cheapest first
    LATENCY = "latency"  # last observed latency, fastest first
    ROUND_ROBIN = "round_robin"  # rotate the candidate list on every call


# Code-level default provider order. Used when the DB has no provider rows and
# the matching settings list (from .env) is empty.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    "text": ["gemini", "openai"],
    "image": ["cqtai"],
    "tts": ["minimax", "elevenlabs", "edge_tts"],
    "storybook": ["gemini"],
}

# API key lookup: adapter name / config_ref -> name of the settings attribute.
API_KEY_MAP: dict[str, str] = {
    # Text
    "gemini": "text_api_key",  # Gemini reuses the text_api_key field
    "text_primary": "text_api_key",  # legacy alias
    "openai": "openai_api_key",
    # Image
    "cqtai": "cqtai_api_key",
    "image_primary": "image_api_key",  # legacy alias
    # TTS
    "minimax": "minimax_api_key",
    "elevenlabs": "elevenlabs_api_key",
    "edge_tts": "tts_api_key",  # EdgeTTS reuses tts_api_key (usually empty)
    "tts_primary": "tts_api_key",  # legacy alias
}

# Per-type round-robin offsets. Every ProviderType needs an entry, otherwise
# the ROUND_ROBIN strategy would fail for that type.
_round_robin_counters: dict[ProviderType, int] = {
    "text": 0,
    "image": 0,
    "tts": 0,
    "storybook": 0,
}

# Last observed successful latency (ms) per adapter name. In-memory only.
_latency_cache: dict[str, float] = {}


def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
    """Resolve an API key from settings.

    ``config_ref`` is consulted first, then ``adapter_name``. Returns ``""``
    when neither is mapped in ``API_KEY_MAP`` or the setting is unset.
    """
    for ref in (config_ref, adapter_name):
        key_attr = API_KEY_MAP.get(ref) if ref else None
        if key_attr:
            return getattr(settings, key_attr, "") or ""
    return ""


def _get_default_config(adapter_name: str) -> AdapterConfig | None:
    """Return the code-level default config for an adapter.

    Used when there is no DB record for the adapter.

    Returns:
        The default ``AdapterConfig``, or ``None`` for an unknown adapter name.
    """
    # --- Text defaults ---
    if adapter_name in ("gemini", "text_primary"):
        return AdapterConfig(
            api_key=settings.text_api_key,
            model=settings.text_model or "gemini-2.0-flash",
            timeout_ms=60000,
        )
    if adapter_name == "openai":
        return AdapterConfig(
            api_key=getattr(settings, "openai_api_key", ""),
            model="gpt-4o-mini",  # hard-coded; could be read from settings
            timeout_ms=60000,
        )

    # --- Image defaults ---
    # Exact comparison on purpose: `in ("cqtai")` would be a substring test,
    # because ("cqtai") is a plain str, not a one-element tuple.
    if adapter_name == "cqtai":
        return AdapterConfig(
            api_key=getattr(settings, "cqtai_api_key", ""),
            model="nano-banana-pro",  # default to the Pro model
            timeout_ms=120000,
        )
    if adapter_name == "image_primary":
        # Legacy alias, kept only as a minimal fallback so remaining references
        # do not fail; candidate for removal.
        return AdapterConfig(api_key=settings.image_api_key, timeout_ms=120000)

    # --- TTS defaults ---
    if adapter_name == "minimax":
        # NOTE(review): AdapterConfig has no group_id field, and this function
        # cannot pass extra constructor arguments to the adapter, so only the
        # base config is returned. Per the earlier note here, the MiniMax TTS
        # adapter does not handle group_id yet.
        return AdapterConfig(
            api_key=getattr(settings, "minimax_api_key", ""),
            model="speech-2.6-turbo",
            timeout_ms=60000,
        )
    if adapter_name == "elevenlabs":
        return AdapterConfig(
            api_key=getattr(settings, "elevenlabs_api_key", ""),
            timeout_ms=120000,
        )
    if adapter_name in ("edge_tts", "tts_primary"):
        return AdapterConfig(
            api_key=settings.tts_api_key,
            api_base=settings.tts_api_base,
            model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
            timeout_ms=120000,
        )

    # --- Others ---
    if adapter_name in ("storybook_primary", "storybook_gemini"):
        return AdapterConfig(
            api_key=settings.text_api_key,  # reuses the Gemini key
            model=settings.text_model,
            timeout_ms=120000,
        )

    # Unknown adapter.
    return None


def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
    """Build an ``AdapterConfig`` from a DB ``Provider`` row.

    Values set on the row take precedence. The API key falls back to settings
    (via ``config_ref`` / adapter name). Other fields fall back to the
    adapter's code-level defaults.
    """
    api_key = getattr(provider, "api_key", None) or ""
    if not api_key:
        api_key = _get_api_key(provider.config_ref, provider.adapter)

    default = _get_default_config(provider.adapter)
    if default is None:
        default = AdapterConfig(api_key="", timeout_ms=60000)

    return AdapterConfig(
        api_key=api_key or default.api_key,
        api_base=provider.api_base or default.api_base,
        model=provider.model or default.model,
        timeout_ms=provider.timeout_ms or default.timeout_ms,
        # `is not None`, not `or`: an explicit 0 (no retries) must be honoured.
        max_retries=(
            provider.max_retries
            if provider.max_retries is not None
            else default.max_retries
        ),
        extra_config=provider.config_json or {},
    )


async def _get_providers_with_config(
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Return candidate providers with their configs.

    DB rows are used when present, in the order ``get_providers`` returns them.
    Otherwise the names come from the settings list for the type, falling back
    to ``DEFAULT_PROVIDERS``. Names without a default config are skipped with a
    warning.

    Returns:
        ``[(adapter_name, config, provider_row_or_None), ...]``
    """
    db_providers = await get_providers(provider_type)
    if db_providers:
        return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]

    settings_map = {
        "text": settings.text_providers,
        "image": settings.image_providers,
        "tts": settings.tts_providers,
    }
    names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]

    result: list[tuple[str, AdapterConfig, "Provider | None"]] = []
    for name in names:
        config = _get_default_config(name)
        if config is None:
            logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
            continue
        result.append((name, config, None))
    return result


def _sort_by_strategy(
    providers: list[tuple[str, AdapterConfig, "Provider | None"]],
    strategy: RoutingStrategy,
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Return ``providers`` ordered according to ``strategy``.

    ROUND_ROBIN also advances the per-type rotation counter.
    """
    if strategy == RoutingStrategy.PRIORITY:
        # Priority desc, then weight desc. Entries without a DB row rank as
        # priority 0 / weight 1; NULL columns on a row get the same defaults.
        def priority_key(item: tuple[str, AdapterConfig, "Provider | None"]) -> tuple:
            row = item[2]
            if not row:
                return (0, -1)
            weight = row.weight if row.weight is not None else 1
            return (-(row.priority or 0), -weight)

        return sorted(providers, key=priority_key)

    if strategy == RoutingStrategy.COST:
        # Estimated cost asc; adapters that cannot be built/estimated go last.
        def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            adapter_class = AdapterRegistry.get(provider_type, item[0])
            if adapter_class:
                try:
                    return adapter_class(item[1]).estimated_cost
                except Exception as exc:  # best-effort ranking only
                    logger.debug("provider_cost_estimate_failed", adapter=item[0], error=str(exc))
            return float("inf")

        return sorted(providers, key=get_cost)

    if strategy == RoutingStrategy.LATENCY:
        # Last observed latency asc; never-measured adapters go last.
        return sorted(providers, key=lambda item: _latency_cache.get(item[0], float("inf")))

    if strategy == RoutingStrategy.ROUND_ROBIN:
        # Rotate the list. Reduce the stored counter modulo the current size so
        # rotation keeps working if the provider list shrank since last call.
        size = max(len(providers), 1)
        offset = _round_robin_counters.get(provider_type, 0) % size
        _round_robin_counters[provider_type] = (offset + 1) % size
        return providers[offset:] + providers[:offset]

    return providers


async def _filter_healthy(
    db: AsyncSession,
    candidates: list[tuple[str, AdapterConfig, "Provider | None"]],
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Drop candidates whose circuit breaker is open.

    If every candidate is open, the first one is kept so it can recover.
    """
    healthy = []
    for item in candidates:
        name, _config, db_provider = item
        provider_id = db_provider.id if db_provider else name
        if await health_checker.is_healthy(db, provider_id):
            healthy.append(item)
        else:
            logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
    return healthy or candidates[:1]


async def _record_success(
    db: AsyncSession,
    *,
    adapter,
    name: str,
    provider_type: ProviderType,
    provider_id,
    db_provider: "Provider | None",
    latency_ms: int,
    user_id: str | None,
) -> None:
    """Record metrics, health and (if ``user_id``) user cost for a success.

    Best-effort: failures are logged and never undo the successful call.
    """
    try:
        await metrics_collector.record_call(
            db,
            provider_id=provider_id,
            success=True,
            latency_ms=latency_ms,
            cost_usd=adapter.estimated_cost,
        )
        await health_checker.record_call_result(db, provider_id, success=True)
        if user_id:
            await cost_tracker.record_cost(
                db,
                user_id=user_id,
                provider_name=name,
                capability=provider_type,
                estimated_cost=adapter.estimated_cost,
                provider_id=provider_id if db_provider else None,
            )
    except Exception as exc:
        logger.warning(
            "provider_bookkeeping_failed",
            provider_type=provider_type,
            adapter=name,
            error=str(exc),
        )


async def _record_failure(
    db: AsyncSession,
    *,
    name: str,
    provider_type: ProviderType,
    provider_id,
    error_msg: str,
) -> None:
    """Record metrics and health for a failed attempt.

    Best-effort: a bookkeeping error must not abort failover to the next
    provider, so it is logged instead of raised.
    """
    try:
        await metrics_collector.record_call(
            db,
            provider_id=provider_id,
            success=False,
            error_message=error_msg,
        )
        await health_checker.record_call_result(
            db, provider_id, success=False, error=error_msg
        )
    except Exception as exc:
        logger.warning(
            "provider_bookkeeping_failed",
            provider_type=provider_type,
            adapter=name,
            error=str(exc),
        )


async def _route_with_failover(
    provider_type: ProviderType,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    **kwargs,
) -> T:
    """Generic provider routing with failover.

    Args:
        provider_type: Capability to route (text/image/tts/storybook).
        strategy: Ordering strategy for the candidates.
        db: Optional DB session. When given, open circuits are skipped and
            metrics/health (and user cost) are recorded.
        user_id: Optional user id for cost tracking (requires ``db``).
        **kwargs: Forwarded to ``adapter.execute``.

    Returns:
        Whatever the first successful adapter's ``execute`` returns.

    Raises:
        ValueError: If no providers are configured, or every attempt failed.
    """
    providers = await _get_providers_with_config(provider_type)
    if not providers:
        raise ValueError(f"No {provider_type} providers configured.")

    candidates = _sort_by_strategy(providers, strategy, provider_type)
    if db is not None:
        candidates = await _filter_healthy(db, candidates)

    errors: list[str] = []
    for name, config, db_provider in candidates:
        adapter_class = AdapterRegistry.get(provider_type, name)
        if not adapter_class:
            errors.append(f"{name}: 适配器未注册")
            continue

        provider_id = db_provider.id if db_provider else name
        logger.debug(
            "provider_attempt",
            provider_type=provider_type,
            adapter=name,
            strategy=strategy.value,
        )

        # Only adapter construction and execution are guarded here. Bookkeeping
        # runs outside this try so its errors cannot trigger a (double-billed)
        # failover after a successful call.
        try:
            adapter = adapter_class(config)
            start = time.perf_counter()  # monotonic; immune to clock changes
            result = await adapter.execute(**kwargs)
            latency_ms = int((time.perf_counter() - start) * 1000)
        except Exception as exc:
            error_msg = str(exc)
            logger.warning(
                "provider_failed",
                provider_type=provider_type,
                adapter=name,
                error=error_msg,
            )
            errors.append(f"{name}: {exc}")
            if db is not None:
                await _record_failure(
                    db,
                    name=name,
                    provider_type=provider_type,
                    provider_id=provider_id,
                    error_msg=error_msg,
                )
            continue

        _latency_cache[name] = latency_ms
        if db is not None:
            await _record_success(
                db,
                adapter=adapter,
                name=name,
                provider_type=provider_type,
                provider_id=provider_id,
                db_provider=db_provider,
                latency_ms=latency_ms,
                user_id=user_id,
            )

        logger.info(
            "provider_success",
            provider_type=provider_type,
            adapter=name,
            latency_ms=latency_ms,
        )
        return result

    raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")


async def generate_story_content(
    input_type: Literal["keywords", "full_story"],
    data: str,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
) -> StoryOutput:
    """Generate or polish a story via the "text" providers, with failover.

    ``user_id`` (optional) enables per-user cost tracking when ``db`` is given.
    """
    return await _route_with_failover(
        "text",
        strategy=strategy,
        db=db,
        user_id=user_id,
        input_type=input_type,
        data=data,
        education_theme=education_theme,
        memory_context=memory_context,
    )


async def generate_image(
    prompt: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    **kwargs,
) -> str:
    """Generate an image via the "image" providers and return its URL.

    Extra ``kwargs`` (including ``user_id``) go to the router / adapter.
    """
    return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)


async def text_to_speech(
    text: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
) -> bytes:
    """Synthesize speech via the "tts" providers and return MP3 bytes."""
    return await _route_with_failover(
        "tts", strategy=strategy, db=db, user_id=user_id, text=text
    )


async def generate_storybook(
    keywords: str,
    page_count: int = 6,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
):
    """Generate a paged storybook via the "storybook" providers, with failover."""
    # Function-local import kept as in the original (possibly to avoid an
    # import cycle or to ensure the module is loaded); only used for typing.
    from app.services.adapters.storybook.primary import Storybook

    result: Storybook = await _route_with_failover(
        "storybook",
        strategy=strategy,
        db=db,
        user_id=user_id,
        keywords=keywords,
        page_count=page_count,
        education_theme=education_theme,
        memory_context=memory_context,
    )
    return result