Files
dreamweaver/backend/app/services/cost_tracker.py
torin b8d3cb4644
Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
wip: snapshot full local workspace state
2026-04-17 18:58:11 +08:00

197 lines
6.6 KiB
Python

"""成本追踪服务。
记录 API 调用成本,支持预算控制。
"""
from datetime import datetime, timedelta
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
# Module-level logger named after this module; the calls in this file pass
# their context fields as keyword arguments.
logger = get_logger(__name__)
class BudgetExceededError(Exception):
    """Error signalling that a budget limit has been exceeded.

    The exception message is built from the three constructor arguments,
    which are also kept as attributes so handlers can inspect them.
    """

    def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
        """Build the error from the budget kind, the usage and the limit.

        Args:
            limit_type: Label of the budget that was exceeded; it is placed
                at the start of the exception message.
            used: Amount already spent, in USD.
            limit: The configured ceiling, in USD.
        """
        message = f"{limit_type} 预算已超限: {used}/{limit} USD"
        super().__init__(message)
        # Raw values stay available to callers that need more than the text.
        self.limit_type, self.used, self.limit = limit_type, used, limit
class CostTracker:
    """Record per-user API call costs and enforce optional budgets.

    The tracker keeps no state of its own; every method receives the
    ``AsyncSession`` it should run its queries on.
    """

    async def record_cost(
        self,
        db: AsyncSession,
        user_id: str,
        provider_name: str,
        capability: str,
        estimated_cost: float,
        provider_id: str | None = None,
    ) -> CostRecord:
        """Store the estimated cost of one API call and commit the session.

        Args:
            db: Session used to add and commit the record.
            user_id: Identifier of the user the call is billed to.
            provider_name: Name of the provider that served the call.
            capability: Capability label the call belongs to.
            estimated_cost: Estimated cost of the call.
            provider_id: Optional provider identifier.

        Returns:
            The ``CostRecord`` that was added to the session.
        """
        record = CostRecord(
            user_id=user_id,
            provider_id=provider_id,
            provider_name=provider_name,
            capability=capability,
            # Convert via str() so the Decimal carries the float's short
            # repr rather than its full binary expansion.
            estimated_cost=Decimal(str(estimated_cost)),
        )
        db.add(record)
        await db.commit()
        logger.debug(
            "cost_recorded",
            user_id=user_id,
            provider=provider_name,
            capability=capability,
            cost=estimated_cost,
        )
        return record

    async def _sum_cost_since(
        self,
        db: AsyncSession,
        user_id: str,
        since: datetime,
    ) -> Decimal:
        """Sum ``estimated_cost`` for a user's records at or after ``since``."""
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= since,
            )
        )
        total = result.scalar()
        # SUM over zero rows yields NULL; report that as zero.
        return Decimal(str(total)) if total else Decimal("0")

    async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since the start of the current UTC day."""
        # NOTE(review): naive UTC is kept on purpose; presumably
        # CostRecord.timestamp is stored as naive UTC - confirm before
        # switching to timezone-aware datetimes.
        today_start = datetime.utcnow().replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        return await self._sum_cost_since(db, user_id, today_start)

    async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since the start of the current UTC month."""
        month_start = datetime.utcnow().replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._sum_cost_since(db, user_id, month_start)

    async def get_cost_by_capability(
        self,
        db: AsyncSession,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Decimal]:
        """Return the user's cost over the last ``days`` days, per capability.

        Args:
            db: Session used for the query.
            user_id: Identifier of the user to aggregate for.
            days: Size of the look-back window, counted back from now (UTC).

        Returns:
            Mapping of capability label to summed cost.
        """
        since = datetime.utcnow() - timedelta(days=days)
        result = await db.execute(
            select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
            .where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
            .group_by(CostRecord.capability)
        )
        # A group's SUM is NULL when all of its costs are NULL; map that to
        # zero instead of letting Decimal("None") raise.
        return {
            capability: Decimal(str(total)) if total is not None else Decimal("0")
            for capability, total in result.all()
        }

    async def check_budget(
        self,
        db: AsyncSession,
        user_id: str,
        estimated_cost: float,
    ) -> bool:
        """Check whether the user's budget allows a call of ``estimated_cost``.

        A user with no budget row, or with a disabled budget, is always
        allowed. A limit that is unset (``None``) is treated as "no limit of
        that kind" and skipped.

        Args:
            db: Session used for the queries.
            user_id: Identifier of the user making the call.
            estimated_cost: Estimated cost of the pending call.

        Returns:
            True when the call is allowed.

        Raises:
            BudgetExceededError: When the daily or monthly total plus the
                pending cost would exceed the corresponding limit;
                ``limit_type`` is ``"daily"`` or ``"monthly"``.
        """
        budget = await self.get_user_budget(db, user_id)
        if not budget or not budget.enabled:
            return True

        pending = Decimal(str(estimated_cost))

        # Daily budget.
        daily_limit = budget.daily_limit_usd
        if daily_limit is not None:
            daily_cost = await self.get_daily_cost(db, user_id)
            if daily_cost + pending > daily_limit:
                raise BudgetExceededError("daily", daily_cost, daily_limit)

        # Monthly budget.
        monthly_limit = budget.monthly_limit_usd
        if monthly_limit is not None:
            monthly_cost = await self.get_monthly_cost(db, user_id)
            if monthly_cost + pending > monthly_limit:
                raise BudgetExceededError("monthly", monthly_cost, monthly_limit)

        return True

    async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
        """Return the user's budget configuration, or None if none exists."""
        result = await db.execute(
            select(UserBudget).where(UserBudget.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def set_user_budget(
        self,
        db: AsyncSession,
        user_id: str,
        daily_limit: float | None = None,
        monthly_limit: float | None = None,
        alert_threshold: float | None = None,
        enabled: bool | None = None,
    ) -> UserBudget:
        """Create or update the user's budget configuration.

        Only arguments that are not ``None`` are written; the other fields
        keep their current values. The session is committed and the row
        refreshed before returning.

        Args:
            db: Session used to load, add and commit the budget.
            user_id: Identifier of the user whose budget is set.
            daily_limit: New daily limit in USD, if given.
            monthly_limit: New monthly limit in USD, if given.
            alert_threshold: New alert threshold, if given.
            enabled: New enabled flag, if given.

        Returns:
            The refreshed ``UserBudget`` row.
        """
        budget = await self.get_user_budget(db, user_id)
        if budget is None:
            budget = UserBudget(user_id=user_id)
            db.add(budget)
        if daily_limit is not None:
            budget.daily_limit_usd = Decimal(str(daily_limit))
        if monthly_limit is not None:
            budget.monthly_limit_usd = Decimal(str(monthly_limit))
        if alert_threshold is not None:
            budget.alert_threshold = Decimal(str(alert_threshold))
        if enabled is not None:
            budget.enabled = enabled
        await db.commit()
        await db.refresh(budget)
        return budget

    async def get_cost_summary(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> dict:
        """Return the user's cost totals and budget usage as plain values.

        Returns:
            A dict with ``daily_cost_usd``, ``monthly_cost_usd``,
            ``by_capability`` and a nested ``budget`` dict. Limit and
            percentage fields are ``None`` when there is no budget row or
            the limit is unset; percentages are also ``None`` for a zero
            limit.
        """
        daily = await self.get_daily_cost(db, user_id)
        monthly = await self.get_monthly_cost(db, user_id)
        by_capability = await self.get_cost_by_capability(db, user_id)
        budget = await self.get_user_budget(db, user_id)

        daily_limit = budget.daily_limit_usd if budget else None
        monthly_limit = budget.monthly_limit_usd if budget else None

        return {
            "daily_cost_usd": float(daily),
            "monthly_cost_usd": float(monthly),
            "by_capability": {k: float(v) for k, v in by_capability.items()},
            "budget": {
                # Guard on None so a budget row with an unset limit does not
                # make float() raise.
                "daily_limit_usd": float(daily_limit)
                if daily_limit is not None
                else None,
                "monthly_limit_usd": float(monthly_limit)
                if monthly_limit is not None
                else None,
                # Truthiness check also skips a zero limit (division by zero).
                "daily_usage_percent": float(daily / daily_limit * 100)
                if daily_limit
                else None,
                "monthly_usage_percent": float(monthly / monthly_limit * 100)
                if monthly_limit
                else None,
                "enabled": budget.enabled if budget else False,
            },
        }
# Global singleton instance.
cost_tracker = CostTracker()