433 lines
14 KiB
Python
433 lines
14 KiB
Python
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||
|
||
import time
|
||
from enum import Enum
|
||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.services.adapters import AdapterConfig, AdapterRegistry
|
||
from app.services.adapters.text.models import StoryOutput
|
||
from app.services.cost_tracker import cost_tracker
|
||
from app.services.provider_cache import get_providers
|
||
from app.services.provider_metrics import health_checker, metrics_collector
|
||
|
||
if TYPE_CHECKING:
|
||
from app.db.admin_models import Provider
|
||
|
||
# Module-level structured logger.
logger = get_logger(__name__)

# Generic result type used as the return annotation of the failover router.
T = TypeVar("T")

# Capability categories that can be routed to a provider.
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||
|
||
|
||
class RoutingStrategy(str, Enum):
    """Strategies for ordering candidate providers before failover."""

    # Order by configured priority (the default strategy).
    PRIORITY = "priority"
    # Order by estimated cost.
    COST = "cost"
    # Order by observed latency.
    LATENCY = "latency"
    # Rotate the starting provider from call to call.
    ROUND_ROBIN = "round_robin"
|
||
|
||
|
||
# Default adapter order per provider type, used when the DB has no provider
# configuration. This is the code-level default, i.e. what applies when the
# corresponding .env settings are empty.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    "text": ["gemini", "openai"],
    "image": ["cqtai"],
    "tts": ["minimax", "elevenlabs", "edge_tts"],
    "storybook": ["gemini"],
}
|
||
|
||
# API key lookup table: adapter name (or legacy alias) -> name of the
# attribute on `settings` that holds the key.
API_KEY_MAP: dict[str, str] = {
    # Text
    "gemini": "text_api_key",  # Gemini still reuses the text_api_key field
    "text_primary": "text_api_key",  # kept for the legacy alias
    "openai": "openai_api_key",
    # Image
    "cqtai": "cqtai_api_key",
    "image_primary": "image_api_key",  # kept for the legacy alias
    # TTS
    "minimax": "minimax_api_key",
    "elevenlabs": "elevenlabs_api_key",
    "edge_tts": "tts_api_key",  # EdgeTTS reuses tts_api_key (usually empty)
    "tts_primary": "tts_api_key",  # kept for the legacy alias
}
|
||
|
||
# Round-robin cursors: for each provider type, the index of the provider that
# the next ROUND_ROBIN call should start from. Every member of ProviderType
# needs an entry here; "storybook" was previously missing, which made
# round-robin routing for storybooks fail with KeyError.
_round_robin_counters: dict[ProviderType, int] = {
    "text": 0,
    "image": 0,
    "tts": 0,
    "storybook": 0,
}
|
||
|
||
# 延迟缓存(内存中,简化实现)
|
||
_latency_cache: dict[str, float] = {}
|
||
|
||
|
||
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
    """Look up an API key in ``settings`` by config_ref or adapter name.

    Args:
        config_ref: Optional reference name; tried first when it is truthy.
        adapter_name: Adapter name; used when ``config_ref`` is empty or is
            not present in ``API_KEY_MAP``.

    Returns:
        The value of the mapped settings attribute, or "" when neither name
        is mapped (or the mapped attribute does not exist on ``settings``).
    """
    # config_ref takes precedence; the adapter name is the fallback.
    for lookup_name in (config_ref or adapter_name, adapter_name):
        settings_attr = API_KEY_MAP.get(lookup_name)
        if settings_attr:
            return getattr(settings, settings_attr, "")
    return ""
|
||
|
||
|
||
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
    """Return the built-in default config for an adapter.

    Used when there is no DB record for the adapter. Legacy ``*_primary``
    aliases are still recognised.

    Args:
        adapter_name: Adapter name or legacy alias.

    Returns:
        An AdapterConfig built from ``settings`` plus hard-coded defaults, or
        None when the adapter name is unknown.
    """
    # --- Text defaults ---
    if adapter_name in ("gemini", "text_primary"):
        return AdapterConfig(
            api_key=settings.text_api_key,
            model=settings.text_model or "gemini-2.0-flash",
            timeout_ms=60000,
        )
    if adapter_name == "openai":
        return AdapterConfig(
            api_key=getattr(settings, "openai_api_key", ""),
            model="gpt-4o-mini",  # could be read from settings if needed
            timeout_ms=60000,
        )

    # --- Image defaults ---
    # Compare with ==: `in ("cqtai")` has no tuple (missing trailing comma), so
    # it is a substring test on the string "cqtai" and matches "", "c", "ai", ...
    if adapter_name == "cqtai":
        return AdapterConfig(
            api_key=getattr(settings, "cqtai_api_key", ""),
            model="nano-banana-pro",  # default to the Pro model
            timeout_ms=120000,
        )
    if adapter_name == "image_primary":
        # Legacy alias kept only so that any remaining callers do not break;
        # it carries no default model.
        return AdapterConfig(
            api_key=settings.image_api_key,
            timeout_ms=120000,
        )

    # --- TTS defaults ---
    if adapter_name == "minimax":
        # AdapterConfig has no group_id field, and no extra arguments can be
        # passed to the adapter constructor from here, so only the basic
        # config is returned. The adapter would have to obtain a group_id
        # some other way (e.g. from settings).
        return AdapterConfig(
            api_key=getattr(settings, "minimax_api_key", ""),
            model="speech-2.6-turbo",
            timeout_ms=60000,
        )
    if adapter_name == "elevenlabs":
        return AdapterConfig(
            api_key=getattr(settings, "elevenlabs_api_key", ""),
            timeout_ms=120000,
        )
    if adapter_name in ("edge_tts", "tts_primary"):
        return AdapterConfig(
            api_key=settings.tts_api_key,
            api_base=settings.tts_api_base,
            model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
            timeout_ms=120000,
        )

    # --- Others ---
    if adapter_name in ("storybook_primary", "storybook_gemini"):
        return AdapterConfig(
            api_key=settings.text_api_key,  # reuses the Gemini/text key
            model=settings.text_model,
            timeout_ms=120000,
        )

    # Unknown adapter.
    return None
|
||
|
||
|
||
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
    """Build an AdapterConfig from a DB Provider record.

    Each field is taken from the provider record when it is truthy and
    otherwise from the adapter's built-in default config.

    Args:
        provider: Provider record read from the database.

    Returns:
        The merged AdapterConfig for this provider.
    """
    # Key stored on the record wins; otherwise resolve it through settings.
    resolved_key = getattr(provider, "api_key", None) or _get_api_key(
        provider.config_ref, provider.adapter
    )

    fallback = _get_default_config(provider.adapter)
    if fallback is None:
        # Unknown adapter: use a minimal placeholder as the fallback source.
        fallback = AdapterConfig(api_key="", timeout_ms=60000)

    # NOTE: `or` treats every falsy value (None, 0, "") as "not set", so an
    # explicit 0 for timeout_ms / max_retries falls back to the default.
    return AdapterConfig(
        api_key=resolved_key or fallback.api_key,
        api_base=provider.api_base or fallback.api_base,
        model=provider.model or fallback.model,
        timeout_ms=provider.timeout_ms or fallback.timeout_ms,
        max_retries=provider.max_retries or fallback.max_retries,
        extra_config=provider.config_json or {},
    )
|
||
|
||
|
||
async def _get_providers_with_config(
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Collect the candidate providers of one type together with their configs.

    DB-backed providers are used when ``get_providers`` returns any. Otherwise
    the adapter names come from settings (or ``DEFAULT_PROVIDERS`` when the
    setting is empty or absent for this type) and each gets its built-in
    default config; names without a default config are skipped with a warning.

    Returns:
        A list of ``(adapter_name, config, provider_or_none)`` tuples, in the
        order the providers/names were obtained. ``provider_or_none`` is None
        for entries that did not come from the DB.
    """
    rows = await get_providers(provider_type)
    if rows:
        return [(row.adapter, _build_config_from_provider(row), row) for row in rows]

    # No DB configuration: fall back to settings, then to the code defaults.
    configured_names = {
        "text": settings.text_providers,
        "image": settings.image_providers,
        "tts": settings.tts_providers,
    }.get(provider_type)
    adapter_names = configured_names or DEFAULT_PROVIDERS[provider_type]

    fallback_entries = []
    for adapter_name in adapter_names:
        default_config = _get_default_config(adapter_name)
        if default_config is not None:
            fallback_entries.append((adapter_name, default_config, None))
            continue
        logger.warning("unknown_adapter_skipped", adapter=adapter_name, provider_type=provider_type)
    return fallback_entries
|
||
|
||
|
||
def _sort_by_strategy(
    providers: list[tuple[str, AdapterConfig, "Provider | None"]],
    strategy: RoutingStrategy,
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Order the candidate providers according to the routing strategy.

    Args:
        providers: ``(adapter_name, config, provider_or_none)`` tuples.
        strategy: Routing strategy to apply.
        provider_type: Provider type; used to resolve adapter classes for COST
            and to select the ROUND_ROBIN cursor.

    Returns:
        The providers in the order they should be attempted. For an
        unrecognised strategy the input list is returned unchanged.
    """
    if strategy == RoutingStrategy.PRIORITY:
        # Priority descending, then weight descending. Entries without a DB
        # record count as priority 0 / weight 1.
        return sorted(
            providers,
            key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
        )

    if strategy == RoutingStrategy.COST:
        # Estimated cost ascending; adapters whose cost cannot be determined
        # sort last.
        def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            adapter_class = AdapterRegistry.get(provider_type, item[0])
            if adapter_class:
                try:
                    return adapter_class(item[1]).estimated_cost
                except Exception as exc:
                    # Best effort: keep routing, but record why this adapter
                    # was pushed to the back instead of swallowing silently.
                    logger.warning(
                        "provider_cost_estimate_failed",
                        provider_type=provider_type,
                        adapter=item[0],
                        error=str(exc),
                    )
            return float("inf")

        return sorted(providers, key=get_cost)

    if strategy == RoutingStrategy.LATENCY:
        # Cached latency ascending; adapters with no recorded latency sort last.
        def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            return _latency_cache.get(item[0], float("inf"))

        return sorted(providers, key=get_latency)

    if strategy == RoutingStrategy.ROUND_ROBIN:
        # Rotate the list so that each call starts one provider later.
        count = len(providers)
        if count == 0:
            return []
        # .get(): tolerate provider types that have no cursor entry yet.
        # % count: the stored cursor may be stale if the list has shrunk,
        # which would otherwise disable the rotation.
        start = _round_robin_counters.get(provider_type, 0) % count
        _round_robin_counters[provider_type] = (start + 1) % count
        return providers[start:] + providers[:start]

    return providers
|
||
|
||
|
||
async def _route_with_failover(
    provider_type: ProviderType,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    **kwargs,
) -> T:
    """Run a request against the configured providers, failing over in order.

    Providers are ordered by ``strategy`` and tried one after another until
    one succeeds; the first successful result is returned.

    Args:
        provider_type: Provider type (text/image/tts/storybook).
        strategy: Routing strategy used to order the providers.
        db: Optional database session. When given, providers reported
            unhealthy by the health checker are skipped, and call metrics and
            health results are recorded.
        user_id: Optional user ID. When given together with ``db``, the
            adapter's estimated cost is recorded for that user.
        **kwargs: Arguments passed through to the adapter's ``execute``.

    Returns:
        Whatever the first successful adapter's ``execute`` returns.

    Raises:
        ValueError: If no provider is configured, or if no attempted provider
            succeeded (the message lists each provider's error).
    """
    providers = await _get_providers_with_config(provider_type)

    if not providers:
        raise ValueError(f"No {provider_type} providers configured.")

    # Order the candidates by strategy.
    sorted_providers = _sort_by_strategy(providers, strategy, provider_type)

    # With a db session, drop providers whose circuit breaker is open.
    if db:
        healthy_providers = []
        for item in sorted_providers:
            name, _config, db_provider = item
            provider_id = db_provider.id if db_provider else name
            if await health_checker.is_healthy(db, provider_id):
                healthy_providers.append(item)
            else:
                logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
        # If every provider is tripped, still try the first one so that the
        # system has a chance to recover.
        if not healthy_providers:
            healthy_providers = sorted_providers[:1]
        sorted_providers = healthy_providers

    errors: list[str] = []
    for name, config, db_provider in sorted_providers:
        adapter_class = AdapterRegistry.get(provider_type, name)
        if not adapter_class:
            errors.append(f"{name}: 适配器未注册")
            continue

        # Providers without a DB record are identified by adapter name.
        provider_id = db_provider.id if db_provider else name

        # NOTE(review): this try block also covers the success bookkeeping
        # below, so an exception from metrics/cost recording is handled as a
        # provider failure and the next provider is attempted even though
        # execute() already succeeded. Confirm whether that is intended.
        try:
            logger.debug(
                "provider_attempt",
                provider_type=provider_type,
                adapter=name,
                strategy=strategy.value,
            )

            adapter = adapter_class(config)

            # Execute and time the call. perf_counter() is monotonic, so the
            # measured latency cannot be distorted (or go negative) when the
            # wall clock is adjusted, unlike time.time().
            start_time = time.perf_counter()
            result = await adapter.execute(**kwargs)
            latency_ms = int((time.perf_counter() - start_time) * 1000)

            # Update the latency cache consulted by the LATENCY strategy.
            _latency_cache[name] = latency_ms

            # Record success metrics.
            if db:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=True,
                    latency_ms=latency_ms,
                    cost_usd=adapter.estimated_cost,
                )
                await health_checker.record_call_result(db, provider_id, success=True)

                # Record the per-user cost.
                if user_id:
                    await cost_tracker.record_cost(
                        db,
                        user_id=user_id,
                        provider_name=name,
                        capability=provider_type,
                        estimated_cost=adapter.estimated_cost,
                        provider_id=provider_id if db_provider else None,
                    )

            logger.info(
                "provider_success",
                provider_type=provider_type,
                adapter=name,
                latency_ms=latency_ms,
            )
            return result

        except Exception as exc:
            error_msg = str(exc)
            logger.warning(
                "provider_failed",
                provider_type=provider_type,
                adapter=name,
                error=error_msg,
            )
            errors.append(f"{name}: {exc}")

            # Record failure metrics.
            if db:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=False,
                    error_message=error_msg,
                )
                await health_checker.record_call_result(
                    db, provider_id, success=False, error=error_msg
                )

    raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||
|
||
|
||
async def generate_story_content(
    input_type: Literal["keywords", "full_story"],
    data: str,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
) -> StoryOutput:
    """Generate or polish a story via the "text" providers, with failover.

    Args:
        input_type: Either "keywords" or "full_story".
        data: The input text handed to the adapter.
        education_theme: Optional value forwarded to the adapter.
        memory_context: Optional value forwarded to the adapter.
        strategy: Routing strategy used to order the providers.
        db: Optional database session forwarded to the router.

    Returns:
        The result produced by the first text provider that succeeds.
    """
    adapter_args = {
        "input_type": input_type,
        "data": data,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    return await _route_with_failover("text", strategy=strategy, db=db, **adapter_args)
|
||
|
||
|
||
async def generate_image(
    prompt: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    **kwargs,
) -> str:
    """Generate an image via the "image" providers and return its URL, with failover.

    Args:
        prompt: Image prompt handed to the adapter.
        strategy: Routing strategy used to order the providers.
        db: Optional database session forwarded to the router.
        **kwargs: Extra keyword arguments forwarded to the router unchanged.

    Returns:
        The result produced by the first image provider that succeeds.
    """
    forwarded = {"prompt": prompt, **kwargs}
    return await _route_with_failover("image", strategy=strategy, db=db, **forwarded)
|
||
|
||
|
||
async def text_to_speech(
    text: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
) -> bytes:
    """Convert text to speech via the "tts" providers, with failover.

    Args:
        text: The text handed to the adapter.
        strategy: Routing strategy used to order the providers.
        db: Optional database session forwarded to the router.

    Returns:
        The audio bytes (MP3) produced by the first TTS provider that succeeds.
    """
    audio = await _route_with_failover("tts", strategy=strategy, db=db, text=text)
    return audio
|
||
|
||
|
||
async def generate_storybook(
    keywords: str,
    page_count: int = 6,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
):
    """Generate a paginated storybook via the "storybook" providers, with failover.

    Args:
        keywords: Keywords handed to the adapter.
        page_count: Page count forwarded to the adapter.
        education_theme: Optional value forwarded to the adapter.
        memory_context: Optional value forwarded to the adapter.
        strategy: Routing strategy used to order the providers.
        db: Optional database session forwarded to the router.

    Returns:
        The result produced by the first storybook provider that succeeds.
    """
    # Imported inside the function, as in the original layout.
    from app.services.adapters.storybook.primary import Storybook

    adapter_args = {
        "keywords": keywords,
        "page_count": page_count,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    storybook: Storybook = await _route_with_failover(
        "storybook", strategy=strategy, db=db, **adapter_args
    )
    return storybook
|