Files
dreamweaver/backend/app/services/provider_metrics.py
torin b8d3cb4644
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
wip: snapshot full local workspace state
2026-04-17 18:58:11 +08:00

249 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""供应商指标收集和健康检查服务。"""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import ProviderHealth, ProviderMetrics
if TYPE_CHECKING:
from app.services.adapters.base import BaseAdapter
logger = get_logger(__name__)

# Circuit breaker: number of consecutive failures after which a provider is
# marked unhealthy.
CIRCUIT_BREAKER_THRESHOLD = 3

# Circuit breaker: seconds after the last recorded check before an unhealthy
# provider is reported as usable again so a retry can probe it.
CIRCUIT_BREAKER_RECOVERY_SECONDS = 60
class MetricsCollector:
    """Record provider API calls and aggregate them over trailing time windows.

    ``record_call`` inserts one ``ProviderMetrics`` row and commits the session.
    The ``get_*`` methods are read-only aggregate queries over rows whose
    ``timestamp`` falls within the last ``window_minutes`` minutes.
    """

    @staticmethod
    def _window_start(window_minutes: int) -> datetime:
        """Return the start of the trailing window that ends now."""
        # NOTE(review): naive UTC; assumes ProviderMetrics.timestamp is stored
        # as naive UTC — confirm against the model.
        return datetime.utcnow() - timedelta(minutes=window_minutes)

    async def record_call(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        latency_ms: int | None = None,
        cost_usd: float | None = None,
        error_message: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """Persist one API call as a ``ProviderMetrics`` row and commit.

        Args:
            db: Session used to add the row; it is committed before returning.
            provider_id: Provider that served the call.
            success: Whether the call succeeded.
            latency_ms: Call latency in milliseconds, if measured.
            cost_usd: Call cost in USD, if known. ``None`` is stored as NULL;
                any number, including ``0.0``, is stored as a ``Decimal``.
            error_message: Error text for a failed call.
            request_id: Optional correlation id.
        """
        metric = ProviderMetrics(
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
            # `is not None` so a free call is recorded as 0 rather than "unknown".
            # Going through str() keeps the float's short repr (0.1 -> Decimal("0.1")).
            cost_usd=Decimal(str(cost_usd)) if cost_usd is not None else None,
            error_message=error_message,
            request_id=request_id,
        )
        db.add(metric)
        await db.commit()
        logger.debug(
            "metrics_recorded",
            provider_id=provider_id,
            success=success,
            latency_ms=latency_ms,
        )

    async def get_success_rate(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the fraction of successful calls within the window.

        Returns:
            ``success_count / total_count``, or ``1.0`` when the window holds
            no calls (no recent data is treated as healthy).
        """
        since = self._window_start(window_minutes)
        result = await db.execute(
            select(
                # Aggregate FILTER clause: counts only rows with success = true.
                func.count().filter(ProviderMetrics.success.is_(True)).label("success_count"),
                func.count().label("total_count"),
            ).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        row = result.one()
        success_count, total_count = row.success_count, row.total_count
        if total_count == 0:
            return 1.0
        return success_count / total_count

    async def get_avg_latency(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the mean latency in milliseconds within the window.

        Rows with a NULL ``latency_ms`` are excluded. Returns ``0.0`` when no
        row in the window has a latency.
        """
        since = self._window_start(window_minutes)
        result = await db.execute(
            select(func.avg(ProviderMetrics.latency_ms)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
                ProviderMetrics.latency_ms.isnot(None),
            )
        )
        avg = result.scalar()
        return float(avg) if avg is not None else 0.0

    async def get_total_cost(
        self,
        db: AsyncSession,
        provider_id: str,
        window_minutes: int = 60,
    ) -> float:
        """Return the total cost in USD within the window.

        NULL costs are skipped by SUM. Returns ``0.0`` when there is nothing
        to sum.
        """
        since = self._window_start(window_minutes)
        result = await db.execute(
            select(func.sum(ProviderMetrics.cost_usd)).where(
                ProviderMetrics.provider_id == provider_id,
                ProviderMetrics.timestamp >= since,
            )
        )
        total = result.scalar()
        return float(total) if total is not None else 0.0
class HealthChecker:
    """Maintain per-provider ``ProviderHealth`` rows with a simple circuit breaker.

    A provider is marked unhealthy only after ``CIRCUIT_BREAKER_THRESHOLD``
    consecutive failures. Any success resets the counter and marks it healthy.
    An unhealthy provider is reported as usable again (so a retry can probe it)
    once ``CIRCUIT_BREAKER_RECOVERY_SECONDS`` have elapsed since its last
    recorded check. Providers with no row are treated as healthy.
    """

    @staticmethod
    def _is_usable(
        is_healthy: bool | None,
        last_check: datetime | None,
        now: datetime,
    ) -> bool:
        """Return True if the provider is healthy or its open circuit has cooled down."""
        if is_healthy:
            return True
        if not last_check:
            # Unhealthy with no check timestamp: no recovery window to measure.
            return False
        return now >= last_check + timedelta(seconds=CIRCUIT_BREAKER_RECOVERY_SECONDS)

    async def check_provider(
        self,
        db: AsyncSession,
        provider_id: str,
        adapter: "BaseAdapter",
    ) -> bool:
        """Run the adapter's health check and record the outcome.

        Any exception raised by ``adapter.health_check()`` is logged and
        counted as a failed check instead of being propagated.

        Returns:
            The raw result of this check (not the circuit-breaker state).
        """
        try:
            is_healthy = await adapter.health_check()
        except Exception as e:
            # Deliberate boundary catch: a crashing health check is a failed check.
            logger.warning("health_check_failed", provider_id=provider_id, error=str(e))
            is_healthy = False

        await self.update_health_status(
            db,
            provider_id,
            is_healthy,
            error=None if is_healthy else "Health check failed",
        )
        return is_healthy

    async def update_health_status(
        self,
        db: AsyncSession,
        provider_id: str,
        is_healthy: bool,
        error: str | None = None,
    ) -> None:
        """Apply one success/failure observation to the provider's row and commit.

        The row is created on first observation. The same threshold rule applies
        to new and existing rows, so a newly seen provider is not marked
        unhealthy by a single failure (unless the threshold is 1).

        Args:
            db: Session used to read, add, and commit the row.
            provider_id: Provider being observed.
            is_healthy: Whether this observation succeeded.
            error: Error text to store on failure; cleared on success.
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()
        if health is None:
            # Start from a clean healthy state. The observation is then applied
            # below exactly as for an existing row.
            health = ProviderHealth(
                provider_id=provider_id,
                is_healthy=True,
                consecutive_failures=0,
                last_error=None,
            )
            db.add(health)

        health.last_check = datetime.utcnow()
        if is_healthy:
            health.is_healthy = True
            health.consecutive_failures = 0
            health.last_error = None
        else:
            # `or 0` guards against a None counter on an existing row.
            health.consecutive_failures = (health.consecutive_failures or 0) + 1
            health.last_error = error
            # Circuit breaker: trip once the consecutive-failure threshold is reached.
            if health.consecutive_failures >= CIRCUIT_BREAKER_THRESHOLD:
                health.is_healthy = False
                logger.warning(
                    "circuit_breaker_triggered",
                    provider_id=provider_id,
                    consecutive_failures=health.consecutive_failures,
                )
        await db.commit()

    async def record_call_result(
        self,
        db: AsyncSession,
        provider_id: str,
        success: bool,
        error: str | None = None,
    ) -> None:
        """Feed a real call's outcome into the health state (see ``update_health_status``)."""
        await self.update_health_status(db, provider_id, success, error)

    async def get_healthy_providers(
        self,
        db: AsyncSession,
        provider_ids: list[str],
    ) -> list[str]:
        """Return the subset of ``provider_ids`` that may currently receive traffic.

        Uses the same rule as ``is_healthy``: providers without a row count as
        healthy, and unhealthy providers are included again once their recovery
        window has elapsed. Input order and duplicates are preserved.
        """
        if not provider_ids:
            return []

        result = await db.execute(
            select(
                ProviderHealth.provider_id,
                ProviderHealth.is_healthy,
                ProviderHealth.last_check,
            ).where(ProviderHealth.provider_id.in_(provider_ids))
        )
        now = datetime.utcnow()
        usable = {
            pid: self._is_usable(healthy, last_check, now)
            for pid, healthy, last_check in result.all()
        }
        # Unrecorded providers default to usable.
        return [pid for pid in provider_ids if usable.get(pid, True)]

    async def is_healthy(
        self,
        db: AsyncSession,
        provider_id: str,
    ) -> bool:
        """Return True if the provider may currently receive traffic.

        True when there is no row, when the row is healthy, or when an
        unhealthy row's recovery window has elapsed (allowing a retry).
        """
        result = await db.execute(
            select(ProviderHealth).where(ProviderHealth.provider_id == provider_id)
        )
        health = result.scalar_one_or_none()
        if health is None:
            return True
        return self._is_usable(health.is_healthy, health.last_check, datetime.utcnow())
# Shared module-level instances. Neither class defines __init__ or instance
# attributes, so a single instance of each can be reused by all callers.
health_checker = HealthChecker()
metrics_collector = MetricsCollector()