refactor: separate provider capability policy
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -13,6 +12,13 @@ from app.services.adapters.text.models import StoryOutput
|
||||
from app.services.cost_tracker import cost_tracker
|
||||
from app.services.provider_cache import get_providers
|
||||
from app.services.provider_metrics import health_checker, metrics_collector
|
||||
from app.services.provider_policy import (
|
||||
API_KEY_MAP,
|
||||
DEFAULT_PROVIDERS,
|
||||
ProviderType,
|
||||
RoutingStrategy,
|
||||
get_provider_names_from_settings,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.db.admin_models import Provider
|
||||
@@ -21,50 +27,9 @@ logger = get_logger(__name__)
|
||||
|
||||
# Generic type variable.
T = TypeVar("T")

# Provider categories; each value is also a key of DEFAULT_PROVIDERS.
# NOTE(review): ProviderType is also imported from
# app.services.provider_policy at the top of this file, so this assignment
# shadows that import — confirm which definition is authoritative.
ProviderType = Literal["text", "image", "tts", "storybook"]
|
||||
|
||||
|
||||
class RoutingStrategy(str, Enum):
    """Routing strategy enumeration.

    Members are ``str`` subclasses, so each one compares equal to its
    plain string value (e.g. ``RoutingStrategy.COST == "cost"``).

    NOTE(review): a name ``RoutingStrategy`` is also imported from
    ``app.services.provider_policy`` at the top of this file; this class
    shadows that import — confirm which definition is authoritative.
    """

    PRIORITY = "priority"  # sort by priority (default)
    COST = "cost"  # sort by cost
    LATENCY = "latency"  # sort by latency
    ROUND_ROBIN = "round_robin"  # round-robin rotation
|
||||
|
||||
|
||||
# Default provider mapping (used when the DB has no configuration).
# This is the "code-level" default policy, i.e. the case where .env is empty.
# Each list is ordered; the first adapter name is listed first.
# NOTE(review): DEFAULT_PROVIDERS is also imported from
# app.services.provider_policy at the top of this file; this assignment
# shadows that import — confirm which definition is authoritative.
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
    "text": ["gemini", "openai"],
    "image": ["cqtai"],
    "tts": ["minimax", "elevenlabs", "edge_tts"],
    "storybook": ["storybook_primary"],
}
|
||||
|
||||
# API key mapping: adapter_name -> settings attribute name.
# NOTE(review): API_KEY_MAP is also imported from
# app.services.provider_policy at the top of this file; this assignment
# shadows that import — confirm which definition is authoritative.
API_KEY_MAP: dict[str, str] = {
    # Text
    "gemini": "text_api_key",  # Gemini still reuses the text_api_key field
    "text_primary": "text_api_key",  # legacy alias kept for compatibility
    "openai": "openai_api_key",
    # Image
    "cqtai": "cqtai_api_key",
    "image_primary": "image_api_key",  # legacy alias kept for compatibility
    # TTS
    "minimax": "minimax_api_key",
    "elevenlabs": "elevenlabs_api_key",
    "edge_tts": "tts_api_key",  # EdgeTTS reuses tts_api_key (usually empty)
    "tts_primary": "tts_api_key",  # legacy alias kept for compatibility
}
|
||||
|
||||
# Round-robin counters, one per provider type, all starting at 0.
# Built from the keys of DEFAULT_PROVIDERS rather than a hand-written key
# list, so every provider type (including "storybook") has a counter.
_round_robin_counters: dict[ProviderType, int] = {
    provider_type: 0 for provider_type in DEFAULT_PROVIDERS
}
|
||||
|
||||
# 延迟缓存(内存中,简化实现)
|
||||
@@ -115,6 +80,13 @@ def _get_default_config(adapter_name: str) -> AdapterConfig | None:
|
||||
model=settings.image_model or "nano-banana-pro",
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "antigravity":
|
||||
return AdapterConfig(
|
||||
api_key=getattr(settings, "antigravity_api_key", ""),
|
||||
api_base=getattr(settings, "antigravity_api_base", ""),
|
||||
model=settings.antigravity_model,
|
||||
timeout_ms=120000,
|
||||
)
|
||||
if adapter_name == "image_primary":
|
||||
# 如果还有地方在用 image_primary,暂时映射到快或者其他
|
||||
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
|
||||
@@ -196,15 +168,7 @@ async def _get_providers_with_config(
|
||||
if db_providers:
|
||||
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
|
||||
|
||||
settings_map = {
|
||||
"text": settings.text_providers,
|
||||
"image": settings.image_providers,
|
||||
"tts": settings.tts_providers,
|
||||
"storybook": settings.storybook_providers,
|
||||
}
|
||||
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
|
||||
if settings.enable_demo_providers and "demo" not in names:
|
||||
names = ["demo", *names]
|
||||
names = get_provider_names_from_settings(provider_type, settings)
|
||||
|
||||
result = []
|
||||
for name in names:
|
||||
|
||||
Reference in New Issue
Block a user