Initial commit: clean project structure

- Backend: FastAPI + SQLAlchemy + Celery (Python 3.11+)
- Frontend: Vue 3 + TypeScript + Pinia + Tailwind
- Admin Frontend: separate Vue 3 app for management
- Docker Compose: 9 services orchestration
- Specs: design prototypes, memory system PRD, product roadmap

Cleanup performed:
- Removed temporary debug scripts from backend root
- Removed deprecated admin_app.py (embedded UI)
- Removed duplicate docs from admin-frontend
- Updated .gitignore for Vite cache and egg-info
This commit is contained in:
zhangtuo
2026-01-20 18:20:03 +08:00
commit e9d7f8832a
241 changed files with 33070 additions and 0 deletions

View File

@@ -0,0 +1,432 @@
"""Provider routing with failover - 基于适配器注册表的智能路由。"""
import time
from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.services.adapters import AdapterConfig, AdapterRegistry
from app.services.adapters.text.models import StoryOutput
from app.services.cost_tracker import cost_tracker
from app.services.provider_cache import get_providers
from app.services.provider_metrics import health_checker, metrics_collector
if TYPE_CHECKING:
from app.db.admin_models import Provider
logger = get_logger(__name__)
T = TypeVar("T")
ProviderType = Literal["text", "image", "tts", "storybook"]
class RoutingStrategy(str, Enum):
"""路由策略枚举。"""
PRIORITY = "priority" # 按优先级排序(默认)
COST = "cost" # 按成本排序
LATENCY = "latency" # 按延迟排序
ROUND_ROBIN = "round_robin" # 轮询
# 默认配置映射(当 DB 无配置时使用)
# 默认配置映射(当 DB 无配置时使用)
# 这是“代码级”的默认策略,对应 .env 为空的情况
DEFAULT_PROVIDERS: dict[ProviderType, list[str]] = {
"text": ["gemini", "openai"],
"image": ["cqtai"],
"tts": ["minimax", "elevenlabs", "edge_tts"],
"storybook": ["gemini"],
}
# API Key 映射adapter_name -> settings 属性名
API_KEY_MAP: dict[str, str] = {
# Text
"gemini": "text_api_key", # Gemini 还是复用 text_api_key 字段
"text_primary": "text_api_key", # 兼容旧别名
"openai": "openai_api_key",
# Image
"cqtai": "cqtai_api_key",
"image_primary": "image_api_key", # 兼容旧别名
# TTS
"minimax": "minimax_api_key",
"elevenlabs": "elevenlabs_api_key",
"edge_tts": "tts_api_key", # EdgeTTS 复用 tts_api_key (通常为空)
"tts_primary": "tts_api_key", # 兼容旧别名
}
# 轮询计数器
_round_robin_counters: dict[ProviderType, int] = {
"text": 0,
"image": 0,
"tts": 0,
}
# 延迟缓存(内存中,简化实现)
_latency_cache: dict[str, float] = {}
def _get_api_key(config_ref: str | None, adapter_name: str) -> str:
"""根据 config_ref 或适配器名称获取 API Key。"""
# 优先使用 config_ref
key_attr = API_KEY_MAP.get(config_ref or adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
# 回退到适配器名称
key_attr = API_KEY_MAP.get(adapter_name, None)
if key_attr:
return getattr(settings, key_attr, "")
return ""
def _get_default_config(adapter_name: str) -> AdapterConfig | None:
"""获取适配器的默认配置(无 DB 记录时使用)。返回 None 表示未知适配器。"""
# --- Text Defaults ---
if adapter_name in ("gemini", "text_primary"):
return AdapterConfig(
api_key=settings.text_api_key,
model=settings.text_model or "gemini-2.0-flash",
timeout_ms=60000,
)
if adapter_name == "openai":
return AdapterConfig(
api_key=getattr(settings, "openai_api_key", ""),
model="gpt-4o-mini", # 这里可以从 settings 读取,看需求
timeout_ms=60000,
)
# --- Image Defaults ---
if adapter_name in ("cqtai"):
return AdapterConfig(
api_key=getattr(settings, "cqtai_api_key", ""),
model="nano-banana-pro", # 默认使用 Pro
timeout_ms=120000,
)
if adapter_name == "image_primary":
# 如果还有地方在用 image_primary暂时映射到快或者其他
# 但既然我们全面整改,最好也删了。这里暂时保留一个空的 fallback 以防报错
return AdapterConfig(
api_key=settings.image_api_key,
timeout_ms=120000
)
# --- TTS Defaults ---
if adapter_name == "minimax":
# 传递 group_id 到 Adapter
# 目前 AdapterConfig 没有 group_id 字段,我们暂时不改 Base
# 而是假设 Adapter 会从 config (通过 kwargs 或其他方式) 拿。
# 实际上我们的 MiniMaxTTSAdapter 还没有处理 group_id。
# 最简单的方法:把 group_id 藏在 api_base 里或者让 Adapter 自己去 settings 拿。
# 鉴于 _build_config_from_provider 里我们无法传递额外参数给 Adapter.__init__
# 我们这里暂时返回基础配置。
return AdapterConfig(
api_key=getattr(settings, "minimax_api_key", ""),
model="speech-2.6-turbo",
timeout_ms=60000,
)
if adapter_name == "elevenlabs":
return AdapterConfig(
api_key=getattr(settings, "elevenlabs_api_key", ""),
timeout_ms=120000,
)
if adapter_name in ("edge_tts", "tts_primary"):
return AdapterConfig(
api_key=settings.tts_api_key,
api_base=settings.tts_api_base,
model=settings.tts_model or "zh-CN-XiaoxiaoNeural",
timeout_ms=120000,
)
# --- Others ---
if adapter_name in ("storybook_primary", "storybook_gemini"):
return AdapterConfig(
api_key=settings.text_api_key, # 复用 Gemini key
model=settings.text_model,
timeout_ms=120000,
)
# 未知适配器返回 None
return None
def _build_config_from_provider(provider: "Provider") -> AdapterConfig:
"""从 DB Provider 记录构建 AdapterConfig。"""
api_key = getattr(provider, "api_key", None) or ""
if not api_key:
api_key = _get_api_key(provider.config_ref, provider.adapter)
default = _get_default_config(provider.adapter)
if default is None:
default = AdapterConfig(api_key="", timeout_ms=60000)
return AdapterConfig(
api_key=api_key or default.api_key,
api_base=provider.api_base or default.api_base,
model=provider.model or default.model,
timeout_ms=provider.timeout_ms or default.timeout_ms,
max_retries=provider.max_retries or default.max_retries,
extra_config=provider.config_json or {},
)
def _get_providers_with_config(
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""获取供应商列表及其配置。
Returns:
[(adapter_name, config, provider_or_none), ...] 按优先级排序
"""
db_providers = get_providers(provider_type)
if db_providers:
return [(p.adapter, _build_config_from_provider(p), p) for p in db_providers]
settings_map = {
"text": settings.text_providers,
"image": settings.image_providers,
"tts": settings.tts_providers,
}
names = settings_map.get(provider_type) or DEFAULT_PROVIDERS[provider_type]
result = []
for name in names:
config = _get_default_config(name)
if config is None:
logger.warning("unknown_adapter_skipped", adapter=name, provider_type=provider_type)
continue
result.append((name, config, None))
return result
def _sort_by_strategy(
providers: list[tuple[str, AdapterConfig, "Provider | None"]],
strategy: RoutingStrategy,
provider_type: ProviderType,
) -> list[tuple[str, AdapterConfig, "Provider | None"]]:
"""按策略排序供应商列表。"""
if strategy == RoutingStrategy.PRIORITY:
# 按 priority 降序, weight 降序
return sorted(
providers,
key=lambda x: (-(x[2].priority if x[2] else 0), -(x[2].weight if x[2] else 1)),
)
if strategy == RoutingStrategy.COST:
# 按预估成本升序
def get_cost(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
adapter_class = AdapterRegistry.get(provider_type, item[0])
if adapter_class:
try:
adapter = adapter_class(item[1])
return adapter.estimated_cost
except Exception:
pass
return float("inf")
return sorted(providers, key=get_cost)
if strategy == RoutingStrategy.LATENCY:
# 按历史延迟升序
def get_latency(item: tuple[str, AdapterConfig, "Provider | None"]) -> float:
return _latency_cache.get(item[0], float("inf"))
return sorted(providers, key=get_latency)
if strategy == RoutingStrategy.ROUND_ROBIN:
# 轮询:旋转列表
counter = _round_robin_counters[provider_type]
_round_robin_counters[provider_type] = (counter + 1) % max(len(providers), 1)
return providers[counter:] + providers[:counter]
return providers
async def _route_with_failover(
provider_type: ProviderType,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
user_id: str | None = None,
**kwargs,
) -> T:
"""通用 provider failover 路由。
Args:
provider_type: 供应商类型 (text/image/tts/storybook)
strategy: 路由策略
db: 数据库会话(可选,用于指标收集和熔断检查)
user_id: 用户 ID可选用于成本追踪和预算检查
**kwargs: 传递给适配器的参数
"""
providers = _get_providers_with_config(provider_type)
if not providers:
raise ValueError(f"No {provider_type} providers configured.")
# 按策略排序
sorted_providers = _sort_by_strategy(providers, strategy, provider_type)
# 如果有 db 会话,过滤掉熔断的供应商
if db:
healthy_providers = []
for item in sorted_providers:
name, config, db_provider = item
provider_id = db_provider.id if db_provider else name
if await health_checker.is_healthy(db, provider_id):
healthy_providers.append(item)
else:
logger.debug("provider_circuit_open", adapter=name, provider_id=provider_id)
# 如果所有供应商都熔断,仍然尝试第一个(允许恢复)
if not healthy_providers:
healthy_providers = sorted_providers[:1]
sorted_providers = healthy_providers
errors: list[str] = []
for name, config, db_provider in sorted_providers:
adapter_class = AdapterRegistry.get(provider_type, name)
if not adapter_class:
errors.append(f"{name}: 适配器未注册")
continue
provider_id = db_provider.id if db_provider else name
try:
logger.debug(
"provider_attempt",
provider_type=provider_type,
adapter=name,
strategy=strategy.value,
)
adapter = adapter_class(config)
# 执行并计时
start_time = time.time()
result = await adapter.execute(**kwargs)
latency_ms = int((time.time() - start_time) * 1000)
# 更新延迟缓存
_latency_cache[name] = latency_ms
# 记录成功指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=True,
latency_ms=latency_ms,
cost_usd=adapter.estimated_cost,
)
await health_checker.record_call_result(db, provider_id, success=True)
# 记录用户成本
if user_id:
await cost_tracker.record_cost(
db,
user_id=user_id,
provider_name=name,
capability=provider_type,
estimated_cost=adapter.estimated_cost,
provider_id=provider_id if db_provider else None,
)
logger.info(
"provider_success",
provider_type=provider_type,
adapter=name,
latency_ms=latency_ms,
)
return result
except Exception as exc:
error_msg = str(exc)
logger.warning(
"provider_failed",
provider_type=provider_type,
adapter=name,
error=error_msg,
)
errors.append(f"{name}: {exc}")
# 记录失败指标
if db:
await metrics_collector.record_call(
db,
provider_id=provider_id,
success=False,
error_message=error_msg,
)
await health_checker.record_call_result(
db, provider_id, success=False, error=error_msg
)
raise ValueError(f"No {provider_type} provider succeeded. Errors: {' | '.join(errors)}")
async def generate_story_content(
input_type: Literal["keywords", "full_story"],
data: str,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> StoryOutput:
"""生成或润色故事,支持 failover。"""
return await _route_with_failover(
"text",
strategy=strategy,
db=db,
input_type=input_type,
data=data,
education_theme=education_theme,
memory_context=memory_context,
)
async def generate_image(
prompt: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
**kwargs,
) -> str:
"""生成图片,返回 URL支持 failover。"""
return await _route_with_failover("image", strategy=strategy, db=db, prompt=prompt, **kwargs)
async def text_to_speech(
text: str,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
) -> bytes:
"""文本转语音,返回 MP3 bytes支持 failover。"""
return await _route_with_failover("tts", strategy=strategy, db=db, text=text)
async def generate_storybook(
keywords: str,
page_count: int = 6,
education_theme: str | None = None,
memory_context: str | None = None,
strategy: RoutingStrategy = RoutingStrategy.PRIORITY,
db: AsyncSession | None = None,
):
"""生成分页故事书,支持 failover。"""
from app.services.adapters.storybook.primary import Storybook
result: Storybook = await _route_with_failover(
"storybook",
strategy=strategy,
db=db,
keywords=keywords,
page_count=page_count,
education_theme=education_theme,
memory_context=memory_context,
)
return result