"""Provider routing with failover - 基于适配器注册表的智能路由。""" import time from typing import TYPE_CHECKING, Literal, TypeVar from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.logging import get_logger from app.services.adapters import AdapterConfig, AdapterRegistry from app.services.adapters.text.models import StoryOutput from app.services.cost_tracker import cost_tracker from app.services.generation_jobs import record_generation_event from app.services.provider_cache import get_providers from app.services.provider_metrics import health_checker, metrics_collector from app.services.provider_policy import ( API_KEY_MAP, DEFAULT_PROVIDERS, ProviderType, RoutingStrategy, get_provider_names_from_settings, ) if TYPE_CHECKING: from app.db.admin_models import Provider from app.db.models import GenerationJob logger = get_logger(__name__) T = TypeVar("T") # 轮询计数器 _round_robin_counters: dict[ProviderType, int] = { provider_type: 0 for provider_type in DEFAULT_PROVIDERS } # 延迟缓存(内存中,简化实现) _latency_cache: dict[str, float] = {} def _safe_estimated_cost(adapter) -> float: """Return an adapter cost value that is safe to serialize in job events.""" try: return float(adapter.estimated_cost) except Exception: return 0.0 async def _record_provider_event_if_present( db: AsyncSession | None, *, job: "GenerationJob | None", event_type: str, status: str, provider_type: ProviderType, adapter_name: str, strategy: RoutingStrategy, provider_id: str | None = None, story_id: int | None = None, latency_ms: int | None = None, estimated_cost: float | None = None, error: str | None = None, ) -> None: """Append provider call telemetry to the active generation job.""" if db is None or job is None: return await record_generation_event( db, job=job, story_id=story_id, event_type=event_type, status=status, message=( f"{provider_type} provider {adapter_name} {status}." if error is None else f"{provider_type} provider {adapter_name} failed." ), metadata={ "capability": provider_type, "adapter": adapter_name, "provider_id": provider_id, "strategy": strategy.value, "latency_ms": latency_ms, "estimated_cost_usd": estimated_cost, "error": error, }, ) def _get_api_key(config_ref: str | None, adapter_name: str) -> str: """根据 config_ref 或适配器名称获取 API Key。""" # 优先使用 config_ref key_attr = API_KEY_MAP.get(config_ref or adapter_name, None) if key_attr: return getattr(settings, key_attr, "") # 回退到适配器名称 key_attr = API_KEY_MAP.get(adapter_name, None) if key_attr: return getattr(settings, key_attr, "") return "" def _get_default_config(adapter_name: str) -> AdapterConfig | None: """获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。""" if adapter_name == "demo": return AdapterConfig( api_key="", model="demo", timeout_ms=1000, ) # --- ASR Defaults --- if adapter_name == "openai_asr": return AdapterConfig( api_key=settings.openai_api_key, model=settings.voice_transcription_model, timeout_ms=60000, ) # --- Text Defaults --- if adapter_name in ("gemini", "text_primary"): return AdapterConfig( api_key=settings.text_api_key, model=settings.text_model or "gemini-2.0-flash", timeout_ms=60000, ) if adapter_name == "openai": return AdapterConfig( api_key=getattr(settings, "openai_api_key", ""), model=settings.openai_model, timeout_ms=60000, ) # --- Image Defaults --- if adapter_name == "cqtai": return AdapterConfig( api_key=getattr(settings, "cqtai_api_key", ""), model=settings.image_model or "nano-banana-pro", timeout_ms=120000, ) if adapter_name == "antigravity": return AdapterConfig( api_key=getattr(settings, "antigravity_api_key", ""), api_base=getattr(settings, "antigravity_api_base", ""), model=settings.antigravity_model, timeout_ms=120000, ) if adapter_name == "image_primary": # 如果还有地方在用 image_primary,暂时映射到快或者其他 # 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错 return AdapterConfig( api_key=settings.image_api_key, timeout_ms=120000 ) # --- TTS Defaults --- if adapter_name == "minimax": # 传递 group_id 到 Adapter # 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base, # 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。 # 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。 # 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。 # 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__, # 我们这里暂时返回基础配置。 return AdapterConfig( api_key=getattr(settings, "minimax_api_key", ""), model=settings.tts_minimax_model, timeout_ms=60000, ) if adapter_name == "elevenlabs": return AdapterConfig( api_key=getattr(settings, "elevenlabs_api_key", ""), timeout_ms=120000, ) if adapter_name in ("edge_tts", "tts_primary"): return AdapterConfig( api_key=settings.tts_api_key, api_base=settings.tts_api_base, model=settings.tts_model or "zh-CN-XiaoxiaoNeural", timeout_ms=120000, ) # --- Others --- if adapter_name in ("storybook_primary", "storybook_gemini"): return AdapterConfig( api_key=settings.text_api_key, # 复用 Gemini key model=settings.text_model, timeout_ms=120000, ) # 未知适配器返回 None return None def _build_config_from_provider(provider: "Provider") -> AdapterConfig: """从 DB Provider 记录构建 AdapterConfig。""" api_key = getattr(provider, "api_key", None) or "" if not api_key: api_key = _get_api_key(provider.config_ref, provider.adapter) default = _get_default_config(provider.adapter) if default is None: default = AdapterConfig(api_key="", timeout_ms=60000) return AdapterConfig( api_key=api_key or default.api_key, api_base=provider.api_base or default.api_base, model=provider.model or default.model, timeout_ms=provider.timeout_ms or default.timeout_ms, max_retries=provider.max_retries or default.max_retries, extra_config=provider.config_json or {}, ) async def _get_providers_with_config( provider_type: ProviderType, ) -> list[tuple[str, AdapterConfig, "Provider | None"]]: """获取供应商列表及其配置。 Returns: [(adapter_name, config, provider_or_none), ...] 按优先级排序 """ db_providers = await get_providers(provider_type) if db_providers: return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers] names = get_provider_names_from_settings(provider_type, settings) result = [] for name in names: config = _get_default_config(name) if config is None: logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type) continue result.append((name, config, None)) return result def _sort_by_strategy( providers: list[tuple[str, AdapterConfig, "Provider | None"]], strategy: RoutingStrategy, provider_type: ProviderType, ) -> list[tuple[str, AdapterConfig, "Provider | None"]]: """按策略排序供应商列表。""" if strategy == RoutingStrategy.PRIORITY: # 按 priority 降序, weight 降序 return sorted( providers, key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)), ) if strategy == RoutingStrategy.COST: # 按预估成本升序 def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float: adapter_class = AdapterRegistry.get(provider_type, item[0]) if adapter_class: try: adapter = adapter_class(item[1]) return adapter.estimated_cost except Exception: pass return float("inf") return sorted(providers, key=get_cost) if strategy == RoutingStrategy.LATENCY: # 按历史延迟升序 def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float: return _latency_cache.get(item[0], float("inf")) return sorted(providers, key=get_latency) if strategy == RoutingStrategy.ROUND_ROBIN: # 轮询:旋转列表 counter = _round_robin_counters[provider_type] _round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1) return providers[counter:] + providers[:counter] return providers async def _route_with_failover( provider_type: ProviderType, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, generation_job: "GenerationJob | None" = None, story_id: int | None = None, **kwargs, ) -> T: """通用 provider failover 路由。 Args: provider_type: 供应商类型 (text/image/tts/storybook/asr) strategy: 路由策略 db: 数据库会话(可选,用于指标收集和熔断检查) user_id: 用户 ID(可选,用于成本追踪和预算检查) generation_job: 生成任务(可选,用于记录 provider 调用轨迹) story_id: 故事 ID(可选,用于关联 provider 事件) **kwargs: 传递给适配器的参数 """ provider_names = kwargs.pop("provider_names", None) if provider_names: providers = [ (name, _get_default_config(name) or AdapterConfig(api_key=""), None) for name in provider_names ] else: providers = await _get_providers_with_config(provider_type) if not providers: raise ValueError(f"No {provider_type} providers configured.") # 按策略排序 sorted_providers = _sort_by_strategy(providers, strategy, provider_type) # 如果有 db 会话,过滤掉后台管理台中已熔断的供应商。 # .env/default provider 没有 providers 表记录,不能写入带外键的健康表。 if db: healthy_providers = [] for item in sorted_providers: name, config, db_provider = item if db_provider is None: healthy_providers.append(item) continue provider_id = db_provider.id if await health_checker.is_healthy(db, provider_id): healthy_providers.append(item) else: logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id) # 如果所有供应商都熔断,仍然尝试第一个(允许恢复) if not healthy_providers: healthy_providers = sorted_providers[:1] sorted_providers = healthy_providers errors: list[str] = [] for name, config, db_provider in sorted_providers: adapter_class = AdapterRegistry.get(provider_type, name) if not adapter_class: errors.append(f"{name}: 适配器未注册") continue provider_id = str(db_provider.id) if db_provider else None estimated_cost: float | None = None start_time: float | None = None try: logger.debug( "provider_attempt", provider_type=provider_type, adapter=name, strategy=strategy.value, ) adapter = adapter_class(config) estimated_cost = _safe_estimated_cost(adapter) await _record_provider_event_if_present( db, job=generation_job, story_id=story_id, event_type="provider_call_started", status="running", provider_type=provider_type, adapter_name=name, provider_id=provider_id, strategy=strategy, estimated_cost=estimated_cost, ) # 执行并计时 start_time = time.time() result = await adapter.execute(**kwargs) latency_ms = int((time.time() - start_time) * 1000) # 更新延迟缓存 _latency_cache[name] = latency_ms # 记录成功指标。Provider 指标/健康表带外键,只记录后台管理台里的真实 provider。 if db and db_provider and provider_id: await metrics_collector.record_call( db, provider_id=provider_id, success=True, latency_ms=latency_ms, cost_usd=estimated_cost, ) await health_checker.record_call_result(db, provider_id, success=True) # 记录用户成本;环境变量/default provider 没有 provider_id,保留 provider_name 即可。 if db and user_id: await cost_tracker.record_cost( db, user_id=user_id, provider_name=name, capability=provider_type, estimated_cost=estimated_cost, provider_id=provider_id, ) await _record_provider_event_if_present( db, job=generation_job, story_id=story_id, event_type="provider_call_succeeded", status="succeeded", provider_type=provider_type, adapter_name=name, provider_id=provider_id, strategy=strategy, latency_ms=latency_ms, estimated_cost=estimated_cost, ) logger.info( "provider_success", provider_type=provider_type, adapter=name, latency_ms=latency_ms, ) return result except Exception as exc: error_msg = str(exc) latency_ms = ( int((time.time() - start_time) * 1000) if start_time is not None else None ) logger.warning( "provider_failed", provider_type=provider_type, adapter=name, error=error_msg, ) errors.append(f"{name}: {exc}") # 记录失败指标。Provider 指标/健康表带外键,只记录后台管理台里的真实 provider。 if db and db_provider and provider_id: await metrics_collector.record_call( db, provider_id=provider_id, success=False, error_message=error_msg, ) await health_checker.record_call_result( db, provider_id, success=False, error=error_msg ) await _record_provider_event_if_present( db, job=generation_job, story_id=story_id, event_type="provider_call_failed", status="failed", provider_type=provider_type, adapter_name=name, provider_id=provider_id, strategy=strategy, latency_ms=latency_ms, estimated_cost=estimated_cost, error=error_msg, ) raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}") async def transcribe_audio( audio_bytes: bytes, file_name: str | None = None, mime_type: str | None = None, transcript_hint: str | None = None, language: str | None = None, provider_names: list[str] | None = None, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, ): """语音转写,支持 provider failover。""" from app.services.adapters.asr.models import TranscriptionOutput result: TranscriptionOutput = await _route_with_failover( "asr", strategy=strategy, db=db, user_id=user_id, audio_bytes=audio_bytes, file_name=file_name, mime_type=mime_type, transcript_hint=transcript_hint, language=language, provider_names=provider_names, ) return result async def generate_story_content( input_type: Literal["keywords", "full_story"], data: str, education_theme: str | None = None, memory_context: str | None = None, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, generation_job: "GenerationJob | None" = None, ) -> StoryOutput: """生成或润色故事,支持 failover。""" return await _route_with_failover( "text", strategy=strategy, db=db, user_id=user_id, generation_job=generation_job, input_type=input_type, data=data, education_theme=education_theme, memory_context=memory_context, ) async def generate_image( prompt: str, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, generation_job: "GenerationJob | None" = None, story_id: int | None = None, **kwargs, ) -> str: """生成图片,返回 URL,支持 failover。""" return await _route_with_failover( "image", strategy=strategy, db=db, user_id=user_id, generation_job=generation_job, story_id=story_id, prompt=prompt, **kwargs, ) async def text_to_speech( text: str, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, generation_job: "GenerationJob | None" = None, story_id: int | None = None, ) -> bytes: """文本转语音,返回 MP3 bytes,支持 failover。""" return await _route_with_failover( "tts", strategy=strategy, db=db, user_id=user_id, generation_job=generation_job, story_id=story_id, text=text, ) async def generate_storybook( keywords: str, page_count: int = 6, education_theme: str | None = None, memory_context: str | None = None, strategy: RoutingStrategy = RoutingStrategy.PRIORITY, db: AsyncSession | None = None, user_id: str | None = None, generation_job: "GenerationJob | None" = None, ): """生成分页故事书,支持 failover。""" from app.services.adapters.storybook.primary import Storybook result: Storybook = await _route_with_failover( "storybook", strategy=strategy, db=db, user_id=user_id, generation_job=generation_job, keywords=keywords, page_count=page_count, education_theme=education_theme, memory_context=memory_context, ) return result