Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
249 lines
7.7 KiB
Python
249 lines
7.7 KiB
Python
"""供应商指标收集和健康检查服务。"""
|
||
|
||
from datetime import datetime, timedelta
|
||
from decimal import Decimal
|
||
from typing import TYPE_CHECKING
|
||
|
||
from sqlalchemy import func, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.logging import get_logger
|
||
from app.db.admin_models import ProviderHealth, ProviderMetrics
|
||
|
||
if TYPE_CHECKING:
|
||
from app.services.adapters.base import BaseAdapter
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# Circuit breaker: number of consecutive failures that trips the breaker.
CIRCUIT_BREAKER_THRESHOLD: int = 3

# Circuit breaker: recovery time, in seconds.
CIRCUIT_BREAKER_RECOVERY_SECONDS: int = 60
|
||
|
||
|
||
class MetricsCollector:
    """Collector for provider call metrics.

    ``record_call`` stores one ``ProviderMetrics`` row per API call; the
    ``get_*`` methods aggregate the rows of one provider whose
    ``ProviderMetrics.timestamp`` lies inside a trailing time window.
    """

    @staticmethod
    def _window_start(window_minutes: int) -> datetime:
        """Return the lower bound of a trailing window of ``window_minutes``.

        NOTE(review): uses naive ``datetime.utcnow()``, so the comparison
        presumably relies on ``ProviderMetrics.timestamp`` being stored as
        naive UTC -- confirm against the model before switching to
        timezone-aware datetimes.
        """
        return datetime.utcnow() - timedelta(minutes=window_minutes)

    async def record_call(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        latency_ms: int | None = None,
        cost_usd: float | None = None,
        error_message: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """Record a single API call and commit it.

        Args:
            db: Session used to add and commit the metric row.
            provider_id: Identifier of the provider that was called.
            success: Whether the call succeeded.
            latency_ms: Call latency in milliseconds, if measured.
            cost_usd: Call cost in USD, if known. ``0.0`` is stored as a
                real zero cost; only ``None`` is stored as NULL.
            error_message: Error text to store with the row, if any.
            request_id: Request identifier to store with the row, if any.
        """
        # Compare against None explicitly: a truthiness test would turn a
        # genuine zero cost into "unknown". Converting through str() keeps
        # binary float noise out of the Decimal value.
        cost = Decimal(str(cost_usd)) if cost_usd is not None else None

        # The row's timestamp is not set here; presumably the model supplies
        # a default -- confirm in ProviderMetrics.
        metric = ProviderMetrics(
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
            cost_usd=cost,
            error_message=error_message,
            request_id=request_id,
        )
        db.add(metric)
        await db.commit()

        logger.debug(
            "metrics_recorded",
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
        )

    async def get_success_rate(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the success rate within the given time window.

        Args:
            db: Session used to run the aggregate query.
            provider_id: Provider whose calls are counted.
            window_minutes: Length of the trailing window in minutes.

        Returns:
            Successful calls divided by total calls, or ``1.0`` when the
            window contains no calls (no data is treated as healthy).
        """
        since = self._window_start(window_minutes)

        result = await db.execute(
            select(
                func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
                func.count().label("total_count"),
            ).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        row = result.one()

        if row.total_count == 0:
            return 1.0  # Assume healthy when there is no data.

        return row.success_count / row.total_count

    async def get_avg_latency(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the average latency (ms) within the given time window.

        Rows without a latency value are excluded from the average.

        Returns:
            The average latency as a float, or ``0.0`` when no row in the
            window has a latency.
        """
        since = self._window_start(window_minutes)

        result = await db.execute(
            select(func.avg(ProviderMetrics.latency_ms)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
                ProviderMetrics.latency_ms.isnot(None),
            )
        )
        avg = result.scalar()
        # AVG over zero rows yields NULL/None.
        return float(avg) if avg is not None else 0.0

    async def get_total_cost(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the total cost (USD) within the given time window.

        Returns:
            The summed ``cost_usd`` as a float, or ``0.0`` when the sum is
            NULL (no rows, or no row with a cost).
        """
        since = self._window_start(window_minutes)

        result = await db.execute(
            select(func.sum(ProviderMetrics.cost_usd)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        total = result.scalar()
        return float(total) if total is not None else 0.0
|
||
|
||
|
||
class HealthChecker:
    """Provider health checker with a consecutive-failure circuit breaker.

    State is kept in one ``ProviderHealth`` row per provider. A provider is
    marked unhealthy once ``CIRCUIT_BREAKER_THRESHOLD`` consecutive failures
    are recorded, and is offered for retry again once
    ``CIRCUIT_BREAKER_RECOVERY_SECONDS`` have passed since its last check.
    Providers without a row are treated as healthy.
    """

    @staticmethod
    def _is_available(
        is_healthy: bool | None,
        last_check: datetime | None,
        now: datetime,
    ) -> bool:
        """Return whether a recorded provider may be used at ``now``.

        A healthy provider is always available. An unhealthy one becomes
        available again (to allow a retry) once the recovery period has
        elapsed since ``last_check``; without a ``last_check`` it stays
        unavailable.
        """
        if is_healthy:
            return True
        if not last_check:
            return False
        recovery_time = last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)
        return now >= recovery_time

    async def check_provider(
        self,
        db: AsyncSession,
        provider_id: str,
        adapter: "BaseAdapter",
    ) -> bool:
        """Run a health check and update the stored status.

        Args:
            db: Session used to persist the result.
            provider_id: Provider being checked.
            adapter: Adapter whose ``health_check()`` coroutine is awaited.

        Returns:
            The result of ``adapter.health_check()``, or ``False`` when it
            raised.
        """
        try:
            is_healthy = await adapter.health_check()
        except Exception as e:
            # Deliberately broad: any failure of the probe itself counts as
            # an unhealthy result instead of propagating to the caller.
            logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
            is_healthy = False

        await self.update_health_status(
            db,
            provider_id,
            is_healthy,
            error=None if is_healthy else "Health check failed",
        )
        return is_healthy

    async def update_health_status(
        self,
        db: AsyncSession,
        provider_id: str,
        is_healthy: bool,
        error: str | None = None,
    ) -> None:
        """Update a provider's health status (including breaker logic).

        A success resets the failure counter and clears the last error. A
        failure increments the counter and stores ``error``; the provider is
        only marked unhealthy once the counter reaches
        ``CIRCUIT_BREAKER_THRESHOLD``. A provider seen for the first time
        follows the same rules, starting from a healthy state with zero
        failures, so a single initial failure does not trip the breaker.

        Args:
            db: Session used to load, add and commit the health row.
            provider_id: Provider whose status is updated.
            is_healthy: Outcome of the check or call being recorded.
            error: Error text stored as ``last_error`` on failure.
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()

        if health is None:
            # Unrecorded providers count as healthy, so start from that state
            # and let the shared logic below apply the actual outcome.
            health = ProviderHealth(
                provider_id=provider_id,
                is_healthy=True,
                consecutive_failures=0,
                last_error=None,
            )
            db.add(health)

        health.last_check = datetime.utcnow()

        if is_healthy:
            health.is_healthy = True
            health.consecutive_failures = 0
            health.last_error = None
        else:
            # Guard against a NULL counter on rows created elsewhere.
            health.consecutive_failures = (health.consecutive_failures or 0) + 1
            health.last_error = error

            # Circuit breaker logic.
            if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
                health.is_healthy = False
                logger.warning(
                    "circuit_breaker_triggered",
                    provider_id=provider_id,
                    consecutive_failures=health.consecutive_failures,
                )

        await db.commit()

    async def record_call_result(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        error: str | None = None,
    ) -> None:
        """Update the health status from the outcome of a call."""
        await self.update_health_status(db, provider_id, success, error)

    async def get_healthy_providers(
        self,
        db: AsyncSession,
        provider_ids: list[str],
    ) -> list[str]:
        """Return the usable providers among ``provider_ids``, in input order.

        Providers without a health record are treated as healthy. Recorded
        providers are included when they are healthy or when their recovery
        period has elapsed, matching ``is_healthy``.
        """
        if not provider_ids:
            return []

        # Load every recorded health state for the requested providers.
        result = await db.execute(
            select(
                ProviderHealth.provider_id,
                ProviderHealth.is_healthy,
                ProviderHealth.last_check,
            ).where(
                ProviderHealth.provider_id.in_(provider_ids),
            )
        )
        now = datetime.utcnow()
        availability = {
            pid: self._is_available(healthy, last_check, now)
            for pid, healthy, last_check in result.all()
        }

        # Unrecorded providers default to healthy.
        return [pid for pid in provider_ids if availability.get(pid, True)]

    async def is_healthy(
        self,
        db: AsyncSession,
        provider_id: str,
    ) -> bool:
        """Return whether the provider is healthy or may be retried.

        Returns ``True`` for a provider without a record, for a healthy
        provider, and for an unhealthy provider whose recovery period has
        elapsed since its last check.
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()

        if health is None:
            return True  # Unrecorded providers default to healthy.

        return self._is_available(health.is_healthy, health.last_check, datetime.utcnow())
|
||
|
||
|
||
# Global singleton instances.
metrics_collector: MetricsCollector = MetricsCollector()
health_checker: HealthChecker = HealthChecker()
|