Initial commit: clean project structure
- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+) - Frontend: Vue 3 + TypeScript + Pinia + Tailwind - Admin Frontend: separate Vue 3 app for management - Docker Compose: 9 services orchestration - Specs: design prototypes, memory system PRD, product roadmap Cleanup performed: - Removed temporary debug scripts from backend root - Removed deprecated admin_app.py (embedded UI) - Removed duplicate docs from admin-frontend - Updated .gitignore for Vite cache and egg-info
This commit is contained in:
432
backend/app/services/provider_router.py
Normal file
432
backend/app/services/provider_router.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.adapters import AdapterConfig, AdapterRegistry
|
||||
from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.provider_cache import get_providers
|
||||
from app.services.provider_metrics import health_checker, metrics_collector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.admin_models import Provider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
|
||||
"""路由策略枚举。"""
|
||||
|
||||
PRIORITY = "priority" # 按优先级排序(默认)
|
||||
COST = "cost" # 按成本排序
|
||||
LATENCY = "latency" # 按延迟排序
|
||||
ROUND_ROBIN = "round_robin" # 轮询
|
||||
|
||||
|
||||
# 默认配置映射(当 DB 无配置时使用)
|
||||
# 默认配置映射(当 DB 无配置时使用)
|
||||
# 这是“代码级”的默认策略,对应 .env 为空的情况
|
||||
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
|
||||
"text": ["gemini", "openai"],
|
||||
"image": ["cqtai"],
|
||||
"tts": ["minimax", "elevenlabs", "edge_tts"],
|
||||
"storybook": ["gemini"],
|
||||
}
|
||||
|
||||
# API Key 映射:adapter_name -> settings 属性名
|
||||
API_KEY_MAP: dict[str, str] = {
|
||||
# Text
|
||||
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
|
||||
"text_primary": "text_api_key", # 兼容旧别名
|
||||
"openai": "openai_api_key",
|
||||
|
||||
# Image
|
||||
"cqtai": "cqtai_api_key",
|
||||
"image_primary": "image_api_key", # 兼容旧别名
|
||||
|
||||
# TTS
|
||||
"minimax": "minimax_api_key",
|
||||
"elevenlabs": "elevenlabs_api_key",
|
||||
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
|
||||
"tts_primary": "tts_api_key", # 兼容旧别名
|
||||
}
|
||||
|
||||
# 轮询计数器
|
||||
_round_robin_counters: dict[ProviderType, int] = {
|
||||
"text": 0,
|
||||
"image": 0,
|
||||
"tts": 0,
|
||||
}
|
||||
|
||||
# 延迟缓存(内存中,简化实现)
|
||||
_latency_cache: dict[str, float] = {}
|
||||
|
||||
|
||||
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
|
||||
"""根据 config_ref 或适配器名称获取 API Key。"""
|
||||
# 优先使用 config_ref
|
||||
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
|
||||
if key_attr:
|
||||
return getattr(settings, key_attr, "")
|
||||
# 回退到适配器名称
|
||||
key_attr = API_KEY_MAP.get(adapter_name, None)
|
||||
if key_attr:
|
||||
return getattr(settings, key_attr, "")
|
||||
return ""
|
||||
|
||||
|
||||
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
|
||||
|
||||
# --- Text Defaults ---
|
||||
if adapter_name in ("gemini", "text_primary"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.text_api_key,
|
||||
model=settings.text_model or "gemini-2.0-flash",
|
||||
timeout_ms=60000,
|
||||
)
|
||||
if adapter_name == "openai":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "openai_api_key", ""),
|
||||
model="gpt-4o-mini", # 这里可以从 settings 读取,看需求
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
# --- Image Defaults ---
|
||||
if adapter_name in ("cqtai"):
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "cqtai_api_key", ""),
|
||||
model="nano-banana-pro", # 默认使用 Pro
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "image_primary":
|
||||
# 如果还有地方在用 image_primary,暂时映射到快或者其他
|
||||
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
|
||||
return AdapterConfig(
|
||||
api_key=settings.image_api_key,
|
||||
timeout_ms=120000
|
||||
)
|
||||
|
||||
# --- TTS Defaults ---
|
||||
if adapter_name == "minimax":
|
||||
# 传递 group_id 到 Adapter
|
||||
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base,
|
||||
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
|
||||
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
|
||||
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
|
||||
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__,
|
||||
# 我们这里暂时返回基础配置。
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "minimax_api_key", ""),
|
||||
model="speech-2.6-turbo",
|
||||
timeout_ms=60000,
|
||||
)
|
||||
|
||||
if adapter_name == "elevenlabs":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "elevenlabs_api_key", ""),
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name in ("edge_tts", "tts_primary"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.tts_api_key,
|
||||
api_base=settings.tts_api_base,
|
||||
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
|
||||
timeout_ms=120000,
|
||||
)
|
||||
|
||||
# --- Others ---
|
||||
if adapter_name in ("storybook_primary", "storybook_gemini"):
|
||||
return AdapterConfig(
|
||||
api_key=settings.text_api_key, # 复用 Gemini key
|
||||
model=settings.text_model,
|
||||
timeout_ms=120000,
|
||||
)
|
||||
|
||||
# 未知适配器返回 None
|
||||
return None
|
||||
|
||||
|
||||
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
|
||||
"""从 DB Provider 记录构建 AdapterConfig。"""
|
||||
api_key = getattr(provider, "api_key", None) or ""
|
||||
if not api_key:
|
||||
api_key = _get_api_key(provider.config_ref, provider.adapter)
|
||||
|
||||
default = _get_default_config(provider.adapter)
|
||||
if default is None:
|
||||
default = AdapterConfig(api_key="", timeout_ms=60000)
|
||||
|
||||
return AdapterConfig(
|
||||
api_key=api_key or default.api_key,
|
||||
api_base=provider.api_base or default.api_base,
|
||||
model=provider.model or default.model,
|
||||
timeout_ms=provider.timeout_ms or default.timeout_ms,
|
||||
max_retries=provider.max_retries or default.max_retries,
|
||||
extra_config=provider.config_json or {},
|
||||
)
|
||||
|
||||
|
||||
def _get_providers_with_config(
|
||||
provider_type: ProviderType,
|
||||
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
||||
"""获取供应商列表及其配置。
|
||||
|
||||
Returns:
|
||||
[(adapter_name, config, provider_or_none), ...] 按优先级排序
|
||||
"""
|
||||
db_providers = get_providers(provider_type)
|
||||
|
||||
if db_providers:
|
||||
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
||||
|
||||
settings_map = {
|
||||
"text": settings.text_providers,
|
||||
"image": settings.image_providers,
|
||||
"tts": settings.tts_providers,
|
||||
}
|
||||
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
|
||||
result = []
|
||||
for name in names:
|
||||
config = _get_default_config(name)
|
||||
if config is None:
|
||||
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
|
||||
continue
|
||||
result.append((name, config, None))
|
||||
return result
|
||||
|
||||
|
||||
def _sort_by_strategy(
|
||||
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
|
||||
strategy: RoutingStrategy,
|
||||
provider_type: ProviderType,
|
||||
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
|
||||
"""按策略排序供应商列表。"""
|
||||
if strategy == RoutingStrategy.PRIORITY:
|
||||
# 按 priority 降序, weight 降序
|
||||
return sorted(
|
||||
providers,
|
||||
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
|
||||
)
|
||||
|
||||
if strategy == RoutingStrategy.COST:
|
||||
# 按预估成本升序
|
||||
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
|
||||
adapter_class = AdapterRegistry.get(provider_type, item[0])
|
||||
if adapter_class:
|
||||
try:
|
||||
adapter = adapter_class(item[1])
|
||||
return adapter.estimated_cost
|
||||
except Exception:
|
||||
pass
|
||||
return float("inf")
|
||||
|
||||
return sorted(providers, key=get_cost)
|
||||
|
||||
if strategy == RoutingStrategy.LATENCY:
|
||||
# 按历史延迟升序
|
||||
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
|
||||
return _latency_cache.get(item[0], float("inf"))
|
||||
|
||||
return sorted(providers, key=get_latency)
|
||||
|
||||
if strategy == RoutingStrategy.ROUND_ROBIN:
|
||||
# 轮询:旋转列表
|
||||
counter = _round_robin_counters[provider_type]
|
||||
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
|
||||
return providers[counter:] + providers[:counter]
|
||||
|
||||
return providers
|
||||
|
||||
|
||||
async def _route_with_failover(
|
||||
provider_type: ProviderType,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
user_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""通用 provider failover 路由。
|
||||
|
||||
Args:
|
||||
provider_type: 供应商类型 (text/image/tts/storybook)
|
||||
strategy: 路由策略
|
||||
db: 数据库会话(可选,用于指标收集和熔断检查)
|
||||
user_id: 用户 ID(可选,用于成本追踪和预算检查)
|
||||
**kwargs: 传递给适配器的参数
|
||||
"""
|
||||
providers = _get_providers_with_config(provider_type)
|
||||
|
||||
if not providers:
|
||||
raise ValueError(f"No {provider_type} providers configured.")
|
||||
|
||||
# 按策略排序
|
||||
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
|
||||
|
||||
# 如果有 db 会话,过滤掉熔断的供应商
|
||||
if db:
|
||||
healthy_providers = []
|
||||
for item in sorted_providers:
|
||||
name, config, db_provider = item
|
||||
provider_id = db_provider.id if db_provider else name
|
||||
if await health_checker.is_healthy(db, provider_id):
|
||||
healthy_providers.append(item)
|
||||
else:
|
||||
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
|
||||
# 如果所有供应商都熔断,仍然尝试第一个(允许恢复)
|
||||
if not healthy_providers:
|
||||
healthy_providers = sorted_providers[:1]
|
||||
sorted_providers = healthy_providers
|
||||
|
||||
errors: list[str] = []
|
||||
for name, config, db_provider in sorted_providers:
|
||||
adapter_class = AdapterRegistry.get(provider_type, name)
|
||||
if not adapter_class:
|
||||
errors.append(f"{name}: 适配器未注册")
|
||||
continue
|
||||
|
||||
provider_id = db_provider.id if db_provider else name
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
"provider_attempt",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
strategy=strategy.value,
|
||||
)
|
||||
|
||||
adapter = adapter_class(config)
|
||||
|
||||
# 执行并计时
|
||||
start_time = time.time()
|
||||
result = await adapter.execute(**kwargs)
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 更新延迟缓存
|
||||
_latency_cache[name] = latency_ms
|
||||
|
||||
# 记录成功指标
|
||||
if db:
|
||||
await metrics_collector.record_call(
|
||||
db,
|
||||
provider_id=provider_id,
|
||||
success=True,
|
||||
latency_ms=latency_ms,
|
||||
cost_usd=adapter.estimated_cost,
|
||||
)
|
||||
await health_checker.record_call_result(db, provider_id, success=True)
|
||||
|
||||
# 记录用户成本
|
||||
if user_id:
|
||||
await cost_tracker.record_cost(
|
||||
db,
|
||||
user_id=user_id,
|
||||
provider_name=name,
|
||||
capability=provider_type,
|
||||
estimated_cost=adapter.estimated_cost,
|
||||
provider_id=provider_id if db_provider else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"provider_success",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
logger.warning(
|
||||
"provider_failed",
|
||||
provider_type=provider_type,
|
||||
adapter=name,
|
||||
error=error_msg,
|
||||
)
|
||||
errors.append(f"{name}: {exc}")
|
||||
|
||||
# 记录失败指标
|
||||
if db:
|
||||
await metrics_collector.record_call(
|
||||
db,
|
||||
provider_id=provider_id,
|
||||
success=False,
|
||||
error_message=error_msg,
|
||||
)
|
||||
await health_checker.record_call_result(
|
||||
db, provider_id, success=False, error=error_msg
|
||||
)
|
||||
|
||||
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
|
||||
|
||||
|
||||
async def generate_story_content(
|
||||
input_type: Literal["keywords", "full_story"],
|
||||
data: str,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
) -> StoryOutput:
|
||||
"""生成或润色故事,支持 failover。"""
|
||||
return await _route_with_failover(
|
||||
"text",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
input_type=input_type,
|
||||
data=data,
|
||||
education_theme=education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
|
||||
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""生成图片,返回 URL,支持 failover。"""
|
||||
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
|
||||
|
||||
|
||||
async def text_to_speech(
|
||||
text: str,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bytes:
|
||||
"""文本转语音,返回 MP3 bytes,支持 failover。"""
|
||||
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
|
||||
|
||||
|
||||
async def generate_storybook(
|
||||
keywords: str,
|
||||
page_count: int = 6,
|
||||
education_theme: str | None = None,
|
||||
memory_context: str | None = None,
|
||||
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
|
||||
db: AsyncSession | None = None,
|
||||
):
|
||||
"""生成分页故事书,支持 failover。"""
|
||||
from app.services.adapters.storybook.primary import Storybook
|
||||
|
||||
result: Storybook = await _route_with_failover(
|
||||
"storybook",
|
||||
strategy=strategy,
|
||||
db=db,
|
||||
keywords=keywords,
|
||||
page_count=page_count,
|
||||
education_theme=education_theme,
|
||||
memory_context=memory_context,
|
||||
)
|
||||
return result
|
||||
Reference in New Issue
Block a user