"""供应商指标收集和健康检查服务。""" from datetime import datetime, timedelta from decimal import Decimal from typing import TYPE_CHECKING from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.db.admin_models import ProviderHealth, ProviderMetrics if TYPE_CHECKING: from app.services.adapters.base import BaseAdapter logger = get_logger(__name__) # 熔断阈值:连续失败次数 CIRCUIT_BREAKER_THRESHOLD = 3 # 熔断恢复时间(秒) CIRCUIT_BREAKER_RECOVERY_SECONDS = 60 class MetricsCollector: """供应商调用指标收集器。""" async def record_call( self, db: AsyncSession, provider_id: str, success: bool, latency_ms: int | None = None, cost_usd: float | None = None, error_message: str | None = None, request_id: str | None = None, ) -> None: """记录一次 API 调用。""" metric = ProviderMetrics( provider_id=provider_id, success=success, latency_ms=latency_ms, cost_usd=Decimal(str(cost_usd)) if cost_usd else None, error_message=error_message, request_id=request_id, ) db.add(metric) await db.commit() logger.debug( "metrics_recorded", provider_id=provider_id, success=success, latency_ms=latency_ms, ) async def get_success_rate( self, db: AsyncSession, provider_id: str, window_minutes: int = 60, ) -> float: """获取指定时间窗口内的成功率。""" since = datetime.utcnow() - timedelta(minutes=window_minutes) result = await db.execute( select( func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"), func.count().label("total_count"), ).where( ProviderMetrics.provider_id == provider_id, ProviderMetrics.timestamp >= since, ) ) row = result.one() success_count, total_count = row.success_count, row.total_count if total_count == 0: return 1.0 # 无数据时假设健康 return success_count / total_count async def get_avg_latency( self, db: AsyncSession, provider_id: str, window_minutes: int = 60, ) -> float: """获取指定时间窗口内的平均延迟(毫秒)。""" since = datetime.utcnow() - timedelta(minutes=window_minutes) result = await db.execute( select(func.avg(ProviderMetrics.latency_ms)).where( ProviderMetrics.provider_id == provider_id, ProviderMetrics.timestamp >= since, ProviderMetrics.latency_ms.isnot(None), ) ) avg = result.scalar() return float(avg) if avg else 0.0 async def get_total_cost( self, db: AsyncSession, provider_id: str, window_minutes: int = 60, ) -> float: """获取指定时间窗口内的总成本(USD)。""" since = datetime.utcnow() - timedelta(minutes=window_minutes) result = await db.execute( select(func.sum(ProviderMetrics.cost_usd)).where( ProviderMetrics.provider_id == provider_id, ProviderMetrics.timestamp >= since, ) ) total = result.scalar() return float(total) if total else 0.0 class HealthChecker: """供应商健康检查器。""" async def check_provider( self, db: AsyncSession, provider_id: str, adapter: "BaseAdapter", ) -> bool: """执行健康检查并更新状态。""" try: is_healthy = await adapter.health_check() except Exception as e: logger.warning("health_check_failed", provider_id=provider_id, error=str(e)) is_healthy = False await self.update_health_status( db, provider_id, is_healthy, error=None if is_healthy else "Health check failed", ) return is_healthy async def update_health_status( self, db: AsyncSession, provider_id: str, is_healthy: bool, error: str | None = None, ) -> None: """更新供应商健康状态(含熔断逻辑)。""" result = await db.execute( select(ProviderHealth).where(ProviderHealth.provider_id == provider_id) ) health = result.scalar_one_or_none() now = datetime.utcnow() if health is None: health = ProviderHealth( provider_id=provider_id, is_healthy=is_healthy, last_check=now, consecutive_failures=0 if is_healthy else 1, last_error=error, ) db.add(health) else: health.last_check = now if is_healthy: health.is_healthy = True health.consecutive_failures = 0 health.last_error = None else: health.consecutive_failures += 1 health.last_error = error # 熔断逻辑 if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD: health.is_healthy = False logger.warning( "circuit_breaker_triggered", provider_id=provider_id, consecutive_failures=health.consecutive_failures, ) await db.commit() async def record_call_result( self, db: AsyncSession, provider_id: str, success: bool, error: str | None = None, ) -> None: """根据调用结果更新健康状态。""" await self.update_health_status(db, provider_id, success, error) async def get_healthy_providers( self, db: AsyncSession, provider_ids: list[str], ) -> list[str]: """获取健康的供应商列表。""" if not provider_ids: return [] # 查询所有已记录的健康状态 result = await db.execute( select(ProviderHealth.provider_id, ProviderHealth.is_healthy).where( ProviderHealth.provider_id.in_(provider_ids), ) ) health_map = {row[0]: row[1] for row in result.all()} # 未记录的供应商默认健康,已记录但不健康的排除 return [ pid for pid in provider_ids if pid not in health_map or health_map[pid] ] async def is_healthy( self, db: AsyncSession, provider_id: str, ) -> bool: """检查供应商是否健康。""" result = await db.execute( select(ProviderHealth).where(ProviderHealth.provider_id == provider_id) ) health = result.scalar_one_or_none() if health is None: return True # 未记录默认健康 # 检查是否可以恢复 if not health.is_healthy and health.last_check: recovery_time = health.last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS) if datetime.utcnow() >= recovery_time: return True # 允许重试 return health.is_healthy # 全局单例 metrics_collector = MetricsCollector() health_checker = HealthChecker()