"""Cost tracking service.

Records API call costs and supports budget control.
"""

from datetime import datetime, timedelta, timezone
from decimal import Decimal

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget

logger = get_logger(__name__)


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``. The result is
    kept naive so comparisons against ``CostRecord.timestamp`` behave exactly as
    they did with ``utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _to_decimal(value: object) -> Decimal:
    """Convert an aggregate query result to ``Decimal``.

    ``SUM`` yields ``None`` when no (non-NULL) rows match; that maps to 0.
    Going through ``str`` avoids binary-float artefacts when *value* is a float.
    """
    if value is None:
        return Decimal("0")
    return Decimal(str(value))


def _usage_percent(used: Decimal, limit: Decimal | None) -> float | None:
    """Return ``used / limit * 100`` as a float, or ``None`` if *limit* is unset or zero."""
    return float(used / limit * 100) if limit else None


class BudgetExceededError(Exception):
    """Raised when a call would push a user's spend past a budget limit.

    Attributes:
        limit_type: Label of the exceeded limit ("日" = daily, "月" = monthly).
        used: Spend already recorded in the period (excluding the pending call).
        limit: The configured limit in USD.
    """

    def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
        self.limit_type = limit_type
        self.used = used
        self.limit = limit
        super().__init__(f"{limit_type} 预算已超限: {used}/{limit} USD")


class CostTracker:
    """Records per-user API call costs and enforces optional budgets."""

    async def record_cost(
        self,
        db: AsyncSession,
        user_id: str,
        provider_name: str,
        capability: str,
        estimated_cost: float,
        provider_id: str | None = None,
    ) -> CostRecord:
        """Persist the cost of a single API call.

        Args:
            db: Async session; the new record is committed on it.
            user_id: User the cost is attributed to.
            provider_name: Name of the provider that served the call.
            capability: Capability/category of the call, used for grouping.
            estimated_cost: Estimated cost in USD.
            provider_id: Optional provider identifier.

        Returns:
            The committed ``CostRecord``, refreshed from the database.
        """
        record = CostRecord(
            user_id=user_id,
            provider_id=provider_id,
            provider_name=provider_name,
            capability=capability,
            estimated_cost=Decimal(str(estimated_cost)),
        )
        db.add(record)
        await db.commit()
        # With the AsyncSession default expire_on_commit=True, touching attributes of
        # an expired instance triggers implicit IO, which fails under asyncio.
        # Refresh (as set_user_budget does) so callers get a usable object.
        await db.refresh(record)
        logger.debug(
            "cost_recorded",
            user_id=user_id,
            provider=provider_name,
            capability=capability,
            cost=estimated_cost,
        )
        return record

    async def _sum_cost_since(
        self, db: AsyncSession, user_id: str, since: datetime
    ) -> Decimal:
        """Sum ``estimated_cost`` for *user_id* over records with ``timestamp >= since``."""
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= since,
            )
        )
        return _to_decimal(result.scalar())

    async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since 00:00 UTC today (0 if none)."""
        today_start = _utcnow_naive().replace(hour=0, minute=0, second=0, microsecond=0)
        return await self._sum_cost_since(db, user_id, today_start)

    async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since the first day of the current UTC month (0 if none)."""
        month_start = _utcnow_naive().replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._sum_cost_since(db, user_id, month_start)

    async def get_cost_by_capability(
        self,
        db: AsyncSession,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Decimal]:
        """Return the user's cost per capability over the last *days* days.

        Returns:
            Mapping of capability -> summed cost. A group whose costs are all
            NULL maps to ``Decimal("0")`` instead of failing.
        """
        since = _utcnow_naive() - timedelta(days=days)
        result = await db.execute(
            select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
            .where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
            .group_by(CostRecord.capability)
        )
        return {capability: _to_decimal(total) for capability, total in result.all()}

    async def check_budget(
        self,
        db: AsyncSession,
        user_id: str,
        estimated_cost: float,
    ) -> bool:
        """Check whether the user's budget allows a call of the given cost.

        Passes if the user has no budget, the budget is disabled, or a given
        limit is unset (``None``). The daily limit is checked before the monthly one.

        Returns:
            ``True`` when the call is allowed.

        Raises:
            BudgetExceededError: If the call would exceed the daily or monthly limit.
        """
        budget = await self.get_user_budget(db, user_id)
        if not budget or not budget.enabled:
            return True

        cost = Decimal(str(estimated_cost))

        # Daily limit
        daily_limit = budget.daily_limit_usd
        if daily_limit is not None:
            daily_cost = await self.get_daily_cost(db, user_id)
            if daily_cost + cost > daily_limit:
                raise BudgetExceededError("日", daily_cost, daily_limit)

        # Monthly limit
        monthly_limit = budget.monthly_limit_usd
        if monthly_limit is not None:
            monthly_cost = await self.get_monthly_cost(db, user_id)
            if monthly_cost + cost > monthly_limit:
                raise BudgetExceededError("月", monthly_cost, monthly_limit)

        return True

    async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
        """Return the user's budget configuration, or ``None`` if none exists."""
        result = await db.execute(
            select(UserBudget).where(UserBudget.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def set_user_budget(
        self,
        db: AsyncSession,
        user_id: str,
        daily_limit: float | None = None,
        monthly_limit: float | None = None,
        alert_threshold: float | None = None,
        enabled: bool | None = None,
    ) -> UserBudget:
        """Create or update the user's budget.

        Only arguments that are not ``None`` are applied; the rest keep their
        current (or model-default) values. The change is committed and the
        budget refreshed before returning.
        """
        budget = await self.get_user_budget(db, user_id)
        if budget is None:
            budget = UserBudget(user_id=user_id)
            db.add(budget)

        if daily_limit is not None:
            budget.daily_limit_usd = Decimal(str(daily_limit))
        if monthly_limit is not None:
            budget.monthly_limit_usd = Decimal(str(monthly_limit))
        if alert_threshold is not None:
            budget.alert_threshold = Decimal(str(alert_threshold))
        if enabled is not None:
            budget.enabled = enabled

        await db.commit()
        await db.refresh(budget)
        return budget

    async def get_cost_summary(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> dict:
        """Return a JSON-friendly summary of the user's spend and budget.

        Limits and usage percentages are ``None`` when there is no budget or the
        corresponding limit is unset; percentages are also ``None`` for a zero limit.
        """
        daily = await self.get_daily_cost(db, user_id)
        monthly = await self.get_monthly_cost(db, user_id)
        by_capability = await self.get_cost_by_capability(db, user_id)
        budget = await self.get_user_budget(db, user_id)

        daily_limit = budget.daily_limit_usd if budget else None
        monthly_limit = budget.monthly_limit_usd if budget else None

        return {
            "daily_cost_usd": float(daily),
            "monthly_cost_usd": float(monthly),
            "by_capability": {k: float(v) for k, v in by_capability.items()},
            "budget": {
                "daily_limit_usd": float(daily_limit) if daily_limit is not None else None,
                "monthly_limit_usd": float(monthly_limit) if monthly_limit is not None else None,
                "daily_usage_percent": _usage_percent(daily, daily_limit),
                "monthly_usage_percent": _usage_percent(monthly, monthly_limit),
                "enabled": budget.enabled if budget else False,
            },
        }


# Module-level shared instance.
cost_tracker = CostTracker()