Files
dreamweaver/backend/app/services/provider_router.py

597 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.generation_jobs import record_generation_event
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
from app.services.provider_policy import (
API_KEY_MAP,
DEFAULT_PROVIDERS,
ProviderType,
RoutingStrategy,
get_provider_names_from_settings,
)
if TYPE_CHECKING:
from app.db.admin_models import Provider
from app.db.models import GenerationJob
# Module-level logger used by every routing helper below.
logger = get_logger(__name__)
# Generic placeholder for the value returned by the failover router.
T = TypeVar("T")
# Round-robin counters: one rotation offset per entry yielded by iterating
# DEFAULT_PROVIDERS. Provider types not listed there have no counter.
_round_robin_counters: dict[ProviderType, int] = {
    provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
# Latency cache (in-memory, simplified implementation): adapter name -> the
# latency in milliseconds of that adapter's most recent successful call.
_latency_cache: dict[str, float] = {}
def _safe_estimated_cost(adapter) -> float:
    """Return an adapter cost value that is safe to serialize in job events.

    Args:
        adapter: Any object exposing an ``estimated_cost`` attribute.

    Returns:
        ``adapter.estimated_cost`` coerced to a finite ``float``. Falls back
        to ``0.0`` when the attribute is missing, raises on access, cannot be
        converted to ``float``, or is ``inf``/``nan`` (non-finite floats are
        not valid JSON numbers).
    """
    import math  # local import: keeps this helper self-contained

    try:
        cost = float(adapter.estimated_cost)
    except Exception:
        # Deliberately broad: ``estimated_cost`` may be a property that runs
        # arbitrary adapter code, and telemetry must never break a call.
        return 0.0
    return cost if math.isfinite(cost) else 0.0
async def _record_provider_event_if_present(
    db: AsyncSession | None,
    *,
    job: "GenerationJob | None",
    event_type: str,
    status: str,
    provider_type: ProviderType,
    adapter_name: str,
    strategy: RoutingStrategy,
    provider_id: str | None = None,
    story_id: int | None = None,
    latency_ms: int | None = None,
    estimated_cost: float | None = None,
    error: str | None = None,
) -> None:
    """Attach provider-call telemetry to the active generation job, if any.

    Does nothing unless both a DB session and a generation job are supplied.
    Otherwise forwards one event to ``record_generation_event`` carrying a
    short human-readable message plus a metadata dict describing the call.
    """
    if job is None or db is None:
        return

    # The message reports "failed" whenever an error string is supplied,
    # regardless of the status value passed in.
    if error is not None:
        summary = f"{provider_type} provider {adapter_name} failed."
    else:
        summary = f"{provider_type} provider {adapter_name} {status}."

    telemetry = {
        "capability": provider_type,
        "adapter": adapter_name,
        "provider_id": provider_id,
        "strategy": strategy.value,
        "latency_ms": latency_ms,
        "estimated_cost_usd": estimated_cost,
        "error": error,
    }

    await record_generation_event(
        db,
        job=job,
        story_id=story_id,
        event_type=event_type,
        status=status,
        message=summary,
        metadata=telemetry,
    )
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
    """Look up an API key in settings via ``config_ref`` or the adapter name.

    ``API_KEY_MAP`` maps a reference name to the settings attribute holding
    the key. ``config_ref`` is tried first (the adapter name stands in when
    it is empty); the adapter name is the fallback. Returns ``""`` when
    neither lookup yields a mapped attribute.
    """
    for lookup_name in (config_ref or adapter_name, adapter_name):
        settings_attr = API_KEY_MAP.get(lookup_name)
        if settings_attr:
            return getattr(settings, settings_attr, "")
    return ""
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
if adapter_name == "demo":
return AdapterConfig(
api_key="",
model="demo",
timeout_ms=1000,
)
# --- ASR Defaults ---
if adapter_name == "openai_asr":
return AdapterConfig(
api_key=settings.openai_api_key,
api_base=getattr(settings, "openai_api_base", ""),
model=settings.voice_transcription_model,
timeout_ms=60000,
)
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
api_key=settings.text_api_key,
model=settings.text_model or "gemini-2.0-flash",
timeout_ms=60000,
)
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
api_base=getattr(settings, "openai_api_base", ""),
model=settings.openai_model,
timeout_ms=60000,
)
# --- Image Defaults ---
if adapter_name == "cqtai":
return AdapterConfig(
api_key=getattr(settings, "cqtai_api_key", ""),
model=settings.image_model or "nano-banana-pro",
timeout_ms=120000,
)
if adapter_name == "antigravity":
return AdapterConfig(
api_key=getattr(settings, "antigravity_api_key", ""),
api_base=getattr(settings, "antigravity_api_base", ""),
model=settings.antigravity_model,
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
return AdapterConfig(
api_key=settings.image_api_key,
timeout_ms=120000
)
# --- TTS Defaults ---
if adapter_name == "minimax":
# 传递 group_id 到 Adapter
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__
# 我们这里暂时返回基础配置。
return AdapterConfig(
api_key=getattr(settings, "minimax_api_key", ""),
model=settings.tts_minimax_model,
timeout_ms=60000,
)
if adapter_name == "elevenlabs":
return AdapterConfig(
api_key=getattr(settings, "elevenlabs_api_key", ""),
timeout_ms=120000,
)
if adapter_name in ("edge_tts", "tts_primary"):
return AdapterConfig(
api_key=settings.tts_api_key,
api_base=settings.tts_api_base,
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
timeout_ms=120000,
)
# --- Others ---
if adapter_name in ("storybook_primary", "storybook_gemini"):
return AdapterConfig(
api_key=settings.text_api_key, # 复用 Gemini key
model=settings.text_model,
timeout_ms=120000,
)
# 未知适配器返回 None
return None
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
    """Build an ``AdapterConfig`` from a DB ``Provider`` record.

    Values stored on the record win; anything falsy on the record falls back
    to the adapter's settings-based default config. The API key is resolved
    in this order: the record's own ``api_key``, then ``_get_api_key``, then
    the default config's key.
    """
    stored_key = getattr(provider, "api_key", None)
    resolved_key = (
        stored_key
        if stored_key
        else _get_api_key(provider.config_ref, provider.adapter)
    )

    fallback = _get_default_config(provider.adapter)
    if fallback is None:
        # Unknown adapter: use a minimal placeholder as the fallback source.
        fallback = AdapterConfig(api_key="", timeout_ms=60000)

    # NOTE(review): `or` merging treats 0 as "unset", so a record with
    # max_retries=0 or timeout_ms=0 takes the fallback value instead.
    # Confirm that is the intended meaning of 0 in the providers table.
    api_base = provider.api_base or fallback.api_base
    model = provider.model or fallback.model
    timeout_ms = provider.timeout_ms or fallback.timeout_ms
    max_retries = provider.max_retries or fallback.max_retries
    extra_config = provider.config_json or {}

    return AdapterConfig(
        api_key=resolved_key or fallback.api_key,
        api_base=api_base,
        model=model,
        timeout_ms=timeout_ms,
        max_retries=max_retries,
        extra_config=extra_config,
    )
async def _get_providers_with_config(
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Return the candidate providers for a type, each with its config.

    Provider records returned by ``get_providers`` take precedence. When
    there are none, adapter names come from settings and are paired with
    their default configs; names without a known default are skipped with a
    warning.

    Returns:
        ``[(adapter_name, config, provider_or_none), ...]`` in the order the
        underlying source returned them.
    """
    records = await get_providers(provider_type)
    if records:
        return [
            (record.adapter, _build_config_from_provider(record), record)
            for record in records
        ]

    candidates = []
    for adapter_name in get_provider_names_from_settings(provider_type, settings):
        default_config = _get_default_config(adapter_name)
        if default_config is not None:
            candidates.append((adapter_name, default_config, None))
            continue
        logger.warning(
            "unknown_adapter_skipped",
            adapter=adapter_name,
            provider_type=provider_type,
        )
    return candidates
def _sort_by_strategy(
    providers: list[tuple[str, AdapterConfig, "Provider | None"]],
    strategy: RoutingStrategy,
    provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
    """Order the provider list according to the routing strategy.

    Args:
        providers: ``(adapter_name, config, provider_or_none)`` tuples.
        strategy: Routing strategy to apply.
        provider_type: Capability being routed; selects the adapter registry
            entry for cost lookups and the round-robin counter.

    Returns:
        A list in attempt order. PRIORITY sorts by priority then weight, both
        descending (entries without a DB record count as priority 0, weight
        1). COST sorts by the adapter's estimated cost, ascending. LATENCY
        sorts by the cached latency, ascending. ROUND_ROBIN rotates the list
        by a per-type counter and advances that counter. Any other strategy
        returns ``providers`` unchanged.
    """
    if strategy == RoutingStrategy.PRIORITY:
        def priority_key(
            item: tuple[str, AdapterConfig, "Provider | None"],
        ) -> tuple:
            record = item[2]
            if not record:
                return (0, -1)
            # Negated so that the ascending sort yields descending order.
            return (-record.priority, -record.weight)

        return sorted(providers, key=priority_key)

    if strategy == RoutingStrategy.COST:
        import math  # local import: only this branch needs it

        unknown_cost = float("inf")

        def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
            adapter_class = AdapterRegistry.get(provider_type, item[0])
            if not adapter_class:
                return unknown_cost
            try:
                cost = float(adapter_class(item[1]).estimated_cost)
            except Exception:
                # Best-effort: an adapter that cannot be built or priced is
                # still a candidate, it just sorts last.
                return unknown_cost
            # NaN compares false against everything and would leave the sort
            # order undefined, so treat it as unknown as well.
            return unknown_cost if math.isnan(cost) else cost

        return sorted(providers, key=get_cost)

    if strategy == RoutingStrategy.LATENCY:
        # Adapters with no recorded latency sort last.
        return sorted(
            providers,
            key=lambda item: _latency_cache.get(item[0], float("inf")),
        )

    if strategy == RoutingStrategy.ROUND_ROBIN:
        size = max(len(providers), 1)
        # .get(): provider types without a pre-seeded counter start at 0
        # instead of raising KeyError. The modulo keeps the offset valid when
        # the provider list is shorter than it was on the previous call.
        offset = _round_robin_counters.get(provider_type, 0) % size
        _round_robin_counters[provider_type] = (offset + 1) % size
        return providers[offset:] + providers[:offset]

    return providers
async def _route_with_failover(
    provider_type: ProviderType,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
    **kwargs,
) -> T:
    """Run a provider call, failing over to the next adapter on error.

    Args:
        provider_type: Provider type (text/image/tts/storybook/asr).
        strategy: Routing strategy used to order the candidates.
        db: Optional DB session, used for metrics collection and
            circuit-breaker checks.
        user_id: Optional user ID, used for cost tracking and budget checks.
        generation_job: Optional generation job; when given (together with
            ``db``), provider call events are appended to it.
        story_id: Optional story ID attached to those provider events.
        **kwargs: Arguments passed to the adapter's ``execute``. A
            ``provider_names`` entry is removed first and, when non-empty,
            replaces the configured provider list with those adapter names
            using their default configs.

    Returns:
        The result of the first adapter whose ``execute`` call succeeds.

    Raises:
        ValueError: If no providers are configured for the type, or if every
            attempted provider failed or was not registered.
    """
    provider_names = kwargs.pop("provider_names", None)
    if provider_names:
        providers = [
            (name, _get_default_config(name) or AdapterConfig(api_key=""), None)
            for name in provider_names
        ]
    else:
        providers = await _get_providers_with_config(provider_type)

    if not providers:
        raise ValueError(f"No {provider_type} providers configured.")

    # Order the candidates according to the strategy.
    sorted_providers = _sort_by_strategy(providers, strategy, provider_type)

    # With a DB session, drop providers whose circuit breaker is open in the
    # admin console. Providers coming from .env/defaults have no row in the
    # providers table, so they cannot be written to (or looked up in) the
    # health table, which has a foreign key to it; they are always kept.
    if db:
        healthy_providers = []
        for item in sorted_providers:
            name, _, db_provider = item
            if db_provider is None:
                healthy_providers.append(item)
                continue
            provider_id = db_provider.id
            if await health_checker.is_healthy(db, provider_id):
                healthy_providers.append(item)
            else:
                logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
        # If every provider is tripped, still try the first one so that it
        # has a chance to recover.
        if not healthy_providers:
            healthy_providers = sorted_providers[:1]
        sorted_providers = healthy_providers

    errors: list[str] = []
    for name, config, db_provider in sorted_providers:
        adapter_class = AdapterRegistry.get(provider_type, name)
        if not adapter_class:
            errors.append(f"{name}: 适配器未注册")
            continue

        provider_id = str(db_provider.id) if db_provider else None
        estimated_cost: float | None = None
        start_time: float | None = None
        # NOTE(review): the try body also covers the bookkeeping that runs
        # after a successful execute(); if that bookkeeping raises, the
        # provider is recorded as failed and the next one is attempted.
        # Confirm this is intended (e.g. for budget enforcement).
        try:
            logger.debug(
                "provider_attempt",
                provider_type=provider_type,
                adapter=name,
                strategy=strategy.value,
            )
            adapter = adapter_class(config)
            estimated_cost = _safe_estimated_cost(adapter)
            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_started",
                status="running",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                estimated_cost=estimated_cost,
            )

            # Execute and time the call. perf_counter() is monotonic, so the
            # measured latency is unaffected by wall-clock adjustments.
            start_time = time.perf_counter()
            result = await adapter.execute(**kwargs)
            latency_ms = int((time.perf_counter() - start_time) * 1000)

            # Update the latency cache used by the LATENCY strategy.
            _latency_cache[name] = latency_ms

            # Record success metrics. The metrics/health tables have a
            # foreign key to providers, so only real admin-console providers
            # are recorded.
            if db and db_provider and provider_id:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=True,
                    latency_ms=latency_ms,
                    cost_usd=estimated_cost,
                )
                await health_checker.record_call_result(db, provider_id, success=True)

            # Record the user's cost. Env/default providers have no
            # provider_id; the provider name alone is kept in that case.
            if db and user_id:
                await cost_tracker.record_cost(
                    db,
                    user_id=user_id,
                    provider_name=name,
                    capability=provider_type,
                    estimated_cost=estimated_cost,
                    provider_id=provider_id,
                )

            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_succeeded",
                status="succeeded",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                latency_ms=latency_ms,
                estimated_cost=estimated_cost,
            )
            logger.info(
                "provider_success",
                provider_type=provider_type,
                adapter=name,
                latency_ms=latency_ms,
            )
            return result
        except Exception as exc:
            error_msg = str(exc)
            # start_time is still None when the failure happened before
            # execute() was reached (adapter construction, start event).
            latency_ms = (
                int((time.perf_counter() - start_time) * 1000)
                if start_time is not None
                else None
            )
            logger.warning(
                "provider_failed",
                provider_type=provider_type,
                adapter=name,
                error=error_msg,
            )
            errors.append(f"{name}: {exc}")

            # Record failure metrics. The metrics/health tables have a
            # foreign key to providers, so only real admin-console providers
            # are recorded.
            if db and db_provider and provider_id:
                await metrics_collector.record_call(
                    db,
                    provider_id=provider_id,
                    success=False,
                    error_message=error_msg,
                )
                await health_checker.record_call_result(
                    db, provider_id, success=False, error=error_msg
                )

            await _record_provider_event_if_present(
                db,
                job=generation_job,
                story_id=story_id,
                event_type="provider_call_failed",
                status="failed",
                provider_type=provider_type,
                adapter_name=name,
                provider_id=provider_id,
                strategy=strategy,
                latency_ms=latency_ms,
                estimated_cost=estimated_cost,
                error=error_msg,
            )

    raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def transcribe_audio(
    audio_bytes: bytes,
    file_name: str | None = None,
    mime_type: str | None = None,
    transcript_hint: str | None = None,
    language: str | None = None,
    provider_names: list[str] | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
):
    """Transcribe speech to text through the "asr" providers, with failover.

    ``provider_names``, when given, restricts routing to those adapters; the
    remaining audio arguments are handed to the adapter unchanged.
    """
    from app.services.adapters.asr.models import TranscriptionOutput

    adapter_args = {
        "audio_bytes": audio_bytes,
        "file_name": file_name,
        "mime_type": mime_type,
        "transcript_hint": transcript_hint,
        "language": language,
        "provider_names": provider_names,
    }
    transcription: TranscriptionOutput = await _route_with_failover(
        "asr",
        strategy=strategy,
        db=db,
        user_id=user_id,
        **adapter_args,
    )
    return transcription
async def generate_story_content(
    input_type: Literal["keywords", "full_story"],
    data: str,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
) -> StoryOutput:
    """Generate or polish a story through the "text" providers, with failover."""
    adapter_args = {
        "input_type": input_type,
        "data": data,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    story = await _route_with_failover(
        "text",
        strategy=strategy,
        db=db,
        user_id=user_id,
        generation_job=generation_job,
        **adapter_args,
    )
    return story
async def generate_image(
    prompt: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
    **kwargs,
) -> str:
    """Generate an image through the "image" providers and return its URL.

    Supports failover. Extra keyword arguments are forwarded to the router
    along with ``prompt``.
    """
    routing_args = {
        "strategy": strategy,
        "db": db,
        "user_id": user_id,
        "generation_job": generation_job,
        "story_id": story_id,
    }
    image_url = await _route_with_failover(
        "image",
        **routing_args,
        prompt=prompt,
        **kwargs,
    )
    return image_url
async def text_to_speech(
    text: str,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
    story_id: int | None = None,
) -> bytes:
    """Synthesize speech through the "tts" providers and return MP3 bytes.

    Supports failover across the configured providers.
    """
    routing_args = {
        "strategy": strategy,
        "db": db,
        "user_id": user_id,
        "generation_job": generation_job,
        "story_id": story_id,
    }
    audio = await _route_with_failover("tts", **routing_args, text=text)
    return audio
async def generate_storybook(
    keywords: str,
    page_count: int = 6,
    education_theme: str | None = None,
    memory_context: str | None = None,
    strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    generation_job: "GenerationJob | None" = None,
):
    """Generate a paginated storybook through the "storybook" providers.

    Supports failover across the configured providers.
    """
    from app.services.adapters.storybook.primary import Storybook

    adapter_args = {
        "keywords": keywords,
        "page_count": page_count,
        "education_theme": education_theme,
        "memory_context": memory_context,
    }
    storybook: Storybook = await _route_with_failover(
        "storybook",
        strategy=strategy,
        db=db,
        user_id=user_id,
        generation_job=generation_job,
        **adapter_args,
    )
    return storybook