596 lines
20 KiB
Python
596 lines
20 KiB
Python
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||
|
||
import time
|
||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.services.adapters import AdapterConfig, AdapterRegistry
|
||
from app.services.adapters.text.models import StoryOutput
|
||
from app.services.cost_tracker import cost_tracker
|
||
from app.services.harness.trace import TraceRecorder
|
||
from app.services.provider_cache import get_providers
|
||
from app.services.provider_metrics import health_checker, metrics_collector
|
||
from app.services.provider_policy import (
|
||
API_KEY_MAP,
|
||
DEFAULT_PROVIDERS,
|
||
ProviderType,
|
||
RoutingStrategy,
|
||
get_provider_names_from_settings,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from app.db.admin_models import Provider
|
||
from app.db.models import GenerationJob
|
||
|
||
logger = get_logger(__name__)

T = TypeVar("T")

# Round-robin cursor per provider type: the rotation offset used by the next
# ROUND_ROBIN sort. Every entry of DEFAULT_PROVIDERS starts at 0.
_round_robin_counters: dict[ProviderType, int] = dict.fromkeys(DEFAULT_PROVIDERS, 0)

# Latency cache (in-memory, simplified implementation): adapter name -> latency
# in milliseconds of that adapter's most recent successful call.
_latency_cache: dict[str, float] = dict()
|
||
|
||
|
||
def _safe_estimated_cost(adapter) -> float:
|
||
"""Return an adapter cost value that is safe to serialize in job events."""
|
||
|
||
try:
|
||
return float(adapter.estimated_cost)
|
||
except Exception:
|
||
return 0.0
|
||
|
||
|
||
async def _record_provider_event_if_present(
    db: AsyncSession | None,
    *,
    job: "GenerationJob | None",
    event_type: str,
    status: str,
    provider_type: ProviderType,
    adapter_name: str,
    strategy: RoutingStrategy,
    provider_id: str | None = None,
    story_id: int | None = None,
    latency_ms: int | None = None,
    estimated_cost: float | None = None,
    error: str | None = None,
) -> None:
    """Record one provider-call step on the active generation job, if any.

    Does nothing when either ``db`` or ``job`` is ``None``.

    Args:
        db: Database session handed to ``TraceRecorder``.
        job: Generation job the step is attached to.
        event_type: Event type forwarded to ``record_step``.
        status: Step status; also used in the message when ``error`` is None.
        provider_type: Capability being routed; stored as ``capability``.
        adapter_name: Name of the adapter that was called.
        strategy: Routing strategy; its ``.value`` is stored in the metadata.
        provider_id: Optional provider id stored in the metadata.
        story_id: Optional story id forwarded to ``record_step``.
        latency_ms: Optional call latency in milliseconds.
        estimated_cost: Optional cost, stored as ``estimated_cost_usd``.
        error: Optional error text; when set the message reads "failed".
    """
    if job is None or db is None:
        return

    recorder = TraceRecorder(db)

    # The message reports the given status unless an error text is supplied,
    # in which case it always says "failed".
    outcome = status if error is None else "failed"
    summary = f"{provider_type} provider {adapter_name} {outcome}."

    details = {
        "capability": provider_type,
        "adapter": adapter_name,
        "provider_id": provider_id,
        "strategy": strategy.value,
        "latency_ms": latency_ms,
        "estimated_cost_usd": estimated_cost,
        "error": error,
    }

    await recorder.record_step(
        job=job,
        story_id=story_id,
        event_type=event_type,
        status=status,
        message=summary,
        metadata=details,
    )
|
||
|
||
|
||
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
    """Look up an API key in ``settings`` via ``config_ref`` or the adapter name.

    ``API_KEY_MAP`` maps a lookup key to the name of a ``settings`` attribute.
    ``config_ref`` is tried first when it is set; the adapter name is the
    fallback lookup key.

    Returns:
        The value of the mapped settings attribute (``""`` if the attribute
        does not exist), or ``""`` when neither lookup key is mapped.
    """
    for lookup_key in (config_ref or adapter_name, adapter_name):
        settings_attr = API_KEY_MAP.get(lookup_key)
        if settings_attr:
            return getattr(settings, settings_attr, "")
    return ""
|
||
|
||
|
||
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
|
||
|
||
if adapter_name == "demo":
|
||
return AdapterConfig(
|
||
api_key="",
|
||
model="demo",
|
||
timeout_ms=1000,
|
||
)
|
||
|
||
# --- ASR Defaults ---
|
||
if adapter_name == "openai_asr":
|
||
return AdapterConfig(
|
||
api_key=settings.openai_api_key,
|
||
api_base=getattr(settings, "openai_api_base", ""),
|
||
model=settings.voice_transcription_model,
|
||
timeout_ms=60000,
|
||
)
|
||
|
||
# --- Text Defaults ---
|
||
if adapter_name in ("gemini", "text_primary"):
|
||
return AdapterConfig(
|
||
api_key=settings.text_api_key,
|
||
model=settings.text_model or "gemini-2.0-flash",
|
||
timeout_ms=60000,
|
||
)
|
||
if adapter_name == "openai":
|
||
return AdapterConfig(
|
||
api_key=getattr(settings, "openai_api_key", ""),
|
||
api_base=getattr(settings, "openai_api_base", ""),
|
||
model=settings.openai_model,
|
||
timeout_ms=60000,
|
||
)
|
||
|
||
# --- Image Defaults ---
|
||
if adapter_name == "cqtai":
|
||
return AdapterConfig(
|
||
api_key=getattr(settings, "cqtai_api_key", ""),
|
||
model=settings.image_model or "nano-banana-pro",
|
||
timeout_ms=120000,
|
||
)
|
||
if adapter_name == "antigravity":
|
||
return AdapterConfig(
|
||
api_key=getattr(settings, "antigravity_api_key", ""),
|
||
api_base=getattr(settings, "antigravity_api_base", ""),
|
||
model=settings.antigravity_model,
|
||
timeout_ms=120000,
|
||
)
|
||
if adapter_name == "image_primary":
|
||
# 如果还有地方在用 image_primary,暂时映射到快或者其他
|
||
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
|
||
return AdapterConfig(
|
||
api_key=settings.image_api_key,
|
||
timeout_ms=120000
|
||
)
|
||
|
||
# --- TTS Defaults ---
|
||
if adapter_name == "minimax":
|
||
# 传递 group_id 到 Adapter
|
||
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base,
|
||
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
|
||
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
|
||
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
|
||
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__,
|
||
# 我们这里暂时返回基础配置。
|
||
return AdapterConfig(
|
||
api_key=getattr(settings, "minimax_api_key", ""),
|
||
model=settings.tts_minimax_model,
|
||
timeout_ms=60000,
|
||
)
|
||
|
||
if adapter_name == "elevenlabs":
|
||
return AdapterConfig(
|
||
api_key=getattr(settings, "elevenlabs_api_key", ""),
|
||
timeout_ms=120000,
|
||
)
|
||
if adapter_name in ("edge_tts", "tts_primary"):
|
||
return AdapterConfig(
|
||
api_key=settings.tts_api_key,
|
||
api_base=settings.tts_api_base,
|
||
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
|
||
timeout_ms=120000,
|
||
)
|
||
|
||
# --- Others ---
|
||
if adapter_name in ("storybook_primary", "storybook_gemini"):
|
||
return AdapterConfig(
|
||
api_key=settings.text_api_key, # 复用 Gemini key
|
||
model=settings.text_model,
|
||
timeout_ms=120000,
|
||
)
|
||
|
||
# 未知适配器返回 None
|
||
return None
|
||
|
||
|
||
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
    """Build an ``AdapterConfig`` from a DB ``Provider`` record.

    Each field uses the provider's own value when it is truthy and otherwise
    falls back to the adapter's default config. The API key is resolved in
    this order: ``provider.api_key`` (if the attribute exists and is
    non-empty), then ``_get_api_key``, then the default config's key.
    """
    resolved_key = getattr(provider, "api_key", None) or _get_api_key(
        provider.config_ref, provider.adapter
    )

    fallback = _get_default_config(provider.adapter)
    if fallback is None:
        # Unknown adapter: fall back to an empty key and a 60 s timeout.
        fallback = AdapterConfig(api_key="", timeout_ms=60000)

    merged = {
        "api_key": resolved_key or fallback.api_key,
        "api_base": provider.api_base or fallback.api_base,
        "model": provider.model or fallback.model,
        "timeout_ms": provider.timeout_ms or fallback.timeout_ms,
        "max_retries": provider.max_retries or fallback.max_retries,
        "extra_config": provider.config_json or {},
    }
    return AdapterConfig(**merged)
|
||
|
||
|
||
async def _get_providers_with_config(
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Return the providers for ``provider_type`` together with their configs.

    Providers returned by ``get_providers`` take precedence. When there are
    none, the adapter names come from settings and each one gets its default
    config; names without a default config are logged and skipped.

    Returns:
        ``[(adapter_name, config, provider_or_none), ...]`` in the order
        supplied by the underlying source. ``provider_or_none`` is ``None``
        for settings-based entries.
    """
    db_rows = await get_providers(provider_type)
    if db_rows:
        configured = []
        for row in db_rows:
            configured.append((row.adapter, _build_config_from_provider(row), row))
        return configured

    from_settings = []
    for adapter_name in get_provider_names_from_settings(provider_type, settings):
        default_config = _get_default_config(adapter_name)
        if default_config is not None:
            from_settings.append((adapter_name, default_config, None))
        else:
            logger.warning(
                "unknown_adapter_skipped",
                adapter=adapter_name,
                provider_type=provider_type,
            )
    return from_settings
|
||
|
||
|
||
def _sort_by_strategy(
    providers: list[tuple[str, AdapterConfig, "Provider | None"]],
    strategy: RoutingStrategy,
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Order the provider list according to the routing strategy.

    Args:
        providers: ``(adapter_name, config, provider_or_none)`` tuples.
        strategy: Routing strategy to apply.
        provider_type: Capability being routed; used for the adapter registry
            lookup (COST) and the round-robin cursor (ROUND_ROBIN).

    Returns:
        A new list for PRIORITY, COST, LATENCY and ROUND_ROBIN; the input
        list itself for any other strategy.

    Side effects:
        ROUND_ROBIN advances the module-level cursor for ``provider_type``.
    """
    if strategy == RoutingStrategy.PRIORITY:
        # Priority descending, then weight descending. Entries without a DB
        # provider count as priority 0 / weight 1.
        def priority_key(item: tuple[str, AdapterConfig, "Provider | None"]) -> tuple:
            provider = item[2]
            if not provider:
                return (0, -1)
            return (-provider.priority, -provider.weight)

        return sorted(providers, key=priority_key)

    if strategy == RoutingStrategy.COST:
        # Estimated cost ascending; anything without a usable cost goes last.
        def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            adapter_class = AdapterRegistry.get(provider_type, item[0])
            if not adapter_class:
                return float("inf")
            try:
                # Coerce to float so a None / non-numeric cost cannot raise
                # a TypeError from inside sorted().
                cost = float(adapter_class(item[1]).estimated_cost)
            except Exception as exc:
                logger.debug(
                    "provider_cost_unavailable",
                    adapter=item[0],
                    provider_type=provider_type,
                    error=str(exc),
                )
                return float("inf")
            # NaN compares False with everything and would make the sort
            # order arbitrary, so treat it like an unknown cost.
            return float("inf") if cost != cost else cost

        return sorted(providers, key=get_cost)

    if strategy == RoutingStrategy.LATENCY:
        # Cached latency ascending; adapters with no cached latency go last.
        def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            return _latency_cache.get(item[0], float("inf"))

        return sorted(providers, key=get_latency)

    if strategy == RoutingStrategy.ROUND_ROBIN:
        # Rotate the list so each call starts one position further along.
        total = len(providers)
        if total == 0:
            return list(providers)
        # .get(): a provider type that was never seeded starts at 0 instead
        # of raising KeyError. The modulo keeps a cursor saved against a
        # longer list inside the current one.
        start = _round_robin_counters.get(provider_type, 0) % total
        _round_robin_counters[provider_type] = (start + 1) % total
        return providers[start:] + providers[:start]

    return providers
|
||
|
||
|
||
async def _route_with_failover(
    provider_type: ProviderType,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
    **kwargs,
) -> T:
    """Route a call to the first provider that succeeds, failing over in order.

    Args:
        provider_type: Provider type (text/image/tts/storybook/asr).
        strategy: Routing strategy used to order the candidates.
        db: Optional database session; enables the circuit-breaker filter,
            metrics, health records, cost tracking and job events.
        user_id: Optional user id; enables per-user cost recording.
        generation_job: Optional generation job that provider call events
            are attached to.
        story_id: Optional story id associated with the provider events.
        **kwargs: Forwarded to the adapter's ``execute``. The special key
            ``provider_names`` is removed first; when non-empty it replaces
            the configured provider list with those adapter names.

    Returns:
        Whatever the first successful adapter's ``execute`` returns.

    Raises:
        ValueError: If no provider is configured, or every candidate failed
            (the message lists each adapter's error).
    """
    provider_names = kwargs.pop("provider_names", None)
    if provider_names:
        providers = [
            (name, _get_default_config(name) or AdapterConfig(api_key=""), None)
            for name in provider_names
        ]
    else:
        providers = await _get_providers_with_config(provider_type)

    if not providers:
        raise ValueError(f"No {provider_type} providers configured.")

    # Order the candidates according to the strategy.
    sorted_providers = _sort_by_strategy(providers, strategy, provider_type)

    # With a db session, drop providers whose circuit breaker is open.
    # .env/default providers have no row in the providers table, so they
    # cannot be written to the foreign-keyed health table and are always kept.
    if db is not None:
        healthy_providers = []
        for item in sorted_providers:
            name, _, db_provider = item
            if db_provider is None:
                healthy_providers.append(item)
                continue

            # NOTE(review): the raw id is passed here, while the record_*
            # calls further down receive str(db_provider.id) - confirm both
            # APIs accept the same key type.
            provider_id = db_provider.id
            if await health_checker.is_healthy(db, provider_id):
                healthy_providers.append(item)
            else:
                logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)

        # If every provider is tripped, still try the first one so that it
        # has a chance to recover.
        if not healthy_providers:
            healthy_providers = sorted_providers[:1]
        sorted_providers = healthy_providers

    errors: list[str] = []
    for name, config, db_provider in sorted_providers:
        adapter_class = AdapterRegistry.get(provider_type, name)
        if not adapter_class:
            errors.append(f"{name}: 适配器未注册")
            continue

        provider_id = str(db_provider.id) if db_provider else None
        estimated_cost: float | None = None
        start_time: float | None = None

        # NOTE(review): the try block also covers the bookkeeping after a
        # successful execute(), so an exception raised there is handled as a
        # failure of this provider and the next candidate is tried.
        try:
            logger.debug(
                "provider_attempt",
                provider_type=provider_type,
                adapter=name,
                strategy=strategy.value,
            )

            adapter = adapter_class(config)
            estimated_cost = _safe_estimated_cost(adapter)

            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_started",
                status="running",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                estimated_cost=estimated_cost,
            )

            # Execute and time the call. perf_counter() is monotonic, so the
            # measured latency cannot be distorted (or go negative) when the
            # system clock is adjusted.
            start_time = time.perf_counter()
            result = await adapter.execute(**kwargs)
            latency_ms = int((time.perf_counter() - start_time) * 1000)

            # Update the latency cache read by the LATENCY strategy.
            _latency_cache[name] = latency_ms

            # Success metrics. The metrics/health tables are foreign-keyed,
            # so only real providers from the admin console are recorded.
            if db is not None and db_provider and provider_id:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=True,
                    latency_ms=latency_ms,
                    cost_usd=estimated_cost,
                )
                await health_checker.record_call_result(db, provider_id, success=True)

            # Per-user cost. Env/default providers have no provider_id; the
            # provider name alone is enough for them.
            if db is not None and user_id:
                await cost_tracker.record_cost(
                    db,
                    user_id=user_id,
                    provider_name=name,
                    capability=provider_type,
                    estimated_cost=estimated_cost,
                    provider_id=provider_id,
                )

            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_succeeded",
                status="succeeded",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                latency_ms=latency_ms,
                estimated_cost=estimated_cost,
            )

            logger.info(
                "provider_success",
                provider_type=provider_type,
                adapter=name,
                latency_ms=latency_ms,
            )
            return result

        except Exception as exc:
            error_msg = str(exc)
            # start_time is still None when the failure happened before the
            # adapter call started; there is no latency to report then.
            latency_ms = (
                int((time.perf_counter() - start_time) * 1000)
                if start_time is not None
                else None
            )
            logger.warning(
                "provider_failed",
                provider_type=provider_type,
                adapter=name,
                error=error_msg,
            )
            errors.append(f"{name}: {exc}")

            # Failure metrics. The metrics/health tables are foreign-keyed,
            # so only real providers from the admin console are recorded.
            if db is not None and db_provider and provider_id:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=False,
                    error_message=error_msg,
                )
                await health_checker.record_call_result(
                    db, provider_id, success=False, error=error_msg
                )

            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_failed",
                status="failed",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                latency_ms=latency_ms,
                estimated_cost=estimated_cost,
                error=error_msg,
            )

    raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||
|
||
|
||
async def transcribe_audio(
    audio_bytes: bytes,
    file_name: str | None = None,
    mime_type: str | None = None,
    transcript_hint: str | None = None,
    language: str | None = None,
    provider_names: list[str] | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
):
    """Transcribe audio through the "asr" providers, with provider failover.

    The audio arguments and ``provider_names`` are forwarded unchanged to
    ``_route_with_failover``; its result is returned as-is.
    """
    from app.services.adapters.asr.models import TranscriptionOutput

    adapter_kwargs = {
        "audio_bytes": audio_bytes,
        "file_name": file_name,
        "mime_type": mime_type,
        "transcript_hint": transcript_hint,
        "language": language,
        "provider_names": provider_names,
    }
    transcription: TranscriptionOutput = await _route_with_failover(
        "asr",
        strategy=strategy,
        db=db,
        user_id=user_id,
        **adapter_kwargs,
    )
    return transcription
|
||
|
||
|
||
async def generate_story_content(
    input_type: Literal["keywords", "full_story"],
    data: str,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
) -> StoryOutput:
    """Generate or polish a story through the "text" providers, with failover.

    The story arguments are forwarded unchanged to ``_route_with_failover``;
    its result is returned as-is.
    """
    adapter_kwargs = {
        "input_type": input_type,
        "data": data,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    story = await _route_with_failover(
        "text",
        strategy=strategy,
        db=db,
        user_id=user_id,
        generation_job=generation_job,
        **adapter_kwargs,
    )
    return story
|
||
|
||
|
||
async def generate_image(
    prompt: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
    **kwargs,
) -> str:
    """Generate an image through the "image" providers and return its URL.

    ``prompt`` and any extra keyword arguments are forwarded unchanged to
    ``_route_with_failover``, which handles provider failover.
    """
    adapter_kwargs = {"prompt": prompt, **kwargs}
    image_url = await _route_with_failover(
        "image",
        strategy=strategy,
        db=db,
        user_id=user_id,
        generation_job=generation_job,
        story_id=story_id,
        **adapter_kwargs,
    )
    return image_url
|
||
|
||
|
||
async def text_to_speech(
    text: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
) -> bytes:
    """Convert text to speech through the "tts" providers; returns MP3 bytes.

    ``text`` is forwarded unchanged to ``_route_with_failover``, which
    handles provider failover.
    """
    routing_options = {
        "strategy": strategy,
        "db": db,
        "user_id": user_id,
        "generation_job": generation_job,
        "story_id": story_id,
    }
    audio = await _route_with_failover("tts", text=text, **routing_options)
    return audio
|
||
|
||
|
||
async def generate_storybook(
    keywords: str,
    page_count: int = 6,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
):
    """Generate a paginated storybook through the "storybook" providers.

    The storybook arguments are forwarded unchanged to
    ``_route_with_failover``, which handles provider failover; its result
    is returned as-is.
    """
    from app.services.adapters.storybook.primary import Storybook

    adapter_kwargs = {
        "keywords": keywords,
        "page_count": page_count,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    storybook: Storybook = await _route_with_failover(
        "storybook",
        strategy=strategy,
        db=db,
        user_id=user_id,
        generation_job=generation_job,
        **adapter_kwargs,
    )
    return storybook
|