Some checks are pending
Build and Push Docker Images / changes (push) Waiting to run
Build and Push Docker Images / build-backend (push) Blocked by required conditions
Build and Push Docker Images / build-frontend (push) Blocked by required conditions
Build and Push Docker Images / build-admin-frontend (push) Blocked by required conditions
197 lines
6.6 KiB
Python
197 lines
6.6 KiB
Python
"""成本追踪服务。
|
|
|
|
记录 API 调用成本,支持预算控制。
|
|
"""
|
|
|
|
from datetime import datetime, timedelta, timezone
from decimal import Decimal

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.admin_models import CostRecord, UserBudget
|
|
|
|
# Module-level logger from the project's logging helper; the methods below call
# it with an event name plus keyword fields (e.g. "cost_recorded", user_id=...).
logger = get_logger(__name__)
|
|
|
|
|
|
class BudgetExceededError(Exception):
    """Signal that an API call would push a user past a configured budget.

    Attributes:
        limit_type: Label of the budget window that was exceeded
            (``CostTracker.check_budget`` passes "日" for daily, "月" for monthly).
        used: Amount already spent in that window, in USD.
        limit: Configured cap for that window, in USD.
    """

    def __init__(self, limit_type: str, used: Decimal, limit: Decimal):
        message = f"{limit_type} 预算已超限: {used}/{limit} USD"
        super().__init__(message)
        self.limit_type, self.used, self.limit = limit_type, used, limit
|
|
|
|
|
|
class CostTracker:
    """Track per-call API costs and enforce per-user budgets.

    Costs are stored as ``CostRecord`` rows and aggregated with SQL ``SUM``.
    Budgets are ``UserBudget`` rows holding daily and monthly USD limits.
    Time windows are computed in UTC using naive datetimes.
    """

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as ``datetime.utcnow()``, which is deprecated
        since Python 3.12. The tzinfo is dropped so the query parameters stay
        naive, exactly as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _to_decimal(value: object) -> Decimal:
        """Convert an aggregate SQL result to ``Decimal``.

        ``SUM`` yields NULL (``None``) when no rows match. That value, and any
        other falsy value, maps to ``Decimal("0")``. Going through ``str``
        avoids binary-float artefacts if the driver returns a ``float``.
        """
        return Decimal(str(value)) if value else Decimal("0")

    async def _sum_cost_since(
        self,
        db: AsyncSession,
        user_id: str,
        since: datetime,
    ) -> Decimal:
        """Return the total ``estimated_cost`` of ``user_id``'s records at or after ``since``."""
        result = await db.execute(
            select(func.sum(CostRecord.estimated_cost)).where(
                CostRecord.user_id == user_id,
                CostRecord.timestamp >= since,
            )
        )
        return self._to_decimal(result.scalar())

    async def record_cost(
        self,
        db: AsyncSession,
        user_id: str,
        provider_name: str,
        capability: str,
        estimated_cost: float,
        provider_id: str | None = None,
    ) -> CostRecord:
        """Persist the estimated cost of one API call and commit it.

        Args:
            db: Async session used to insert and commit the record.
            user_id: User the cost is attributed to.
            provider_name: Name of the provider that served the call.
            capability: Capability/category label used for per-capability stats.
            estimated_cost: Estimated cost in USD. It is stored as ``Decimal``
                via ``str`` to avoid float representation noise.
            provider_id: Optional provider identifier.

        Returns:
            The committed ``CostRecord``, refreshed so its attributes can be
            read after the commit.
        """
        record = CostRecord(
            user_id=user_id,
            provider_id=provider_id,
            provider_name=provider_name,
            capability=capability,
            estimated_cost=Decimal(str(estimated_cost)),
        )
        db.add(record)
        await db.commit()
        # Reload after commit, as set_user_budget does. With AsyncSession's
        # default expire_on_commit=True, touching the expired attributes would
        # otherwise trigger implicit lazy IO, which async sessions reject.
        await db.refresh(record)

        logger.debug(
            "cost_recorded",
            user_id=user_id,
            provider=provider_name,
            capability=capability,
            cost=estimated_cost,
        )
        return record

    async def get_daily_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since 00:00 UTC today (``Decimal("0")`` if none)."""
        today_start = self._utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        return await self._sum_cost_since(db, user_id, today_start)

    async def get_monthly_cost(self, db: AsyncSession, user_id: str) -> Decimal:
        """Return the user's total cost since 00:00 UTC on the 1st of this month."""
        month_start = self._utcnow().replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._sum_cost_since(db, user_id, month_start)

    async def get_cost_by_capability(
        self,
        db: AsyncSession,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Decimal]:
        """Return the user's total cost per capability over the last ``days`` days.

        The window is a rolling ``days * 24h`` ending now (UTC), not aligned
        to midnight. A group whose SUM is NULL is reported as ``Decimal("0")``
        instead of raising on ``Decimal("None")``.
        """
        since = self._utcnow() - timedelta(days=days)

        result = await db.execute(
            select(CostRecord.capability, func.sum(CostRecord.estimated_cost))
            .where(CostRecord.user_id == user_id, CostRecord.timestamp >= since)
            .group_by(CostRecord.capability)
        )
        return {
            capability: self._to_decimal(total) for capability, total in result.all()
        }

    async def check_budget(
        self,
        db: AsyncSession,
        user_id: str,
        estimated_cost: float,
    ) -> bool:
        """Check whether the user's budget allows a call of ``estimated_cost`` USD.

        The daily limit is checked first. The monthly total is only queried
        if the daily check passes.

        Returns:
            True if the call is allowed. This includes the case where the user
            has no budget row or the budget is disabled.

        Raises:
            BudgetExceededError: if current spend plus ``estimated_cost`` would
                exceed the daily ("日") or monthly ("月") limit. The error's
                ``used`` value is the spend *before* this call.
        """
        budget = await self.get_user_budget(db, user_id)
        if not budget or not budget.enabled:
            return True

        # NOTE(review): this check and record_cost run as separate steps, so
        # concurrent calls may both pass before either is recorded.
        cost = Decimal(str(estimated_cost))

        daily_cost = await self.get_daily_cost(db, user_id)
        if daily_cost + cost > budget.daily_limit_usd:
            raise BudgetExceededError("日", daily_cost, budget.daily_limit_usd)

        monthly_cost = await self.get_monthly_cost(db, user_id)
        if monthly_cost + cost > budget.monthly_limit_usd:
            raise BudgetExceededError("月", monthly_cost, budget.monthly_limit_usd)

        return True

    async def get_user_budget(self, db: AsyncSession, user_id: str) -> UserBudget | None:
        """Return the user's ``UserBudget`` row, or ``None`` if none exists.

        Uses ``scalar_one_or_none``, so more than one matching row raises.
        """
        result = await db.execute(
            select(UserBudget).where(UserBudget.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def set_user_budget(
        self,
        db: AsyncSession,
        user_id: str,
        daily_limit: float | None = None,
        monthly_limit: float | None = None,
        alert_threshold: float | None = None,
        enabled: bool | None = None,
    ) -> UserBudget:
        """Create or update the user's budget. Only arguments that are not ``None`` are applied.

        Numeric values are stored as ``Decimal`` (converted via ``str``). The
        change is committed and the row is refreshed before it is returned.
        """
        budget = await self.get_user_budget(db, user_id)

        if budget is None:
            budget = UserBudget(user_id=user_id)
            db.add(budget)

        if daily_limit is not None:
            budget.daily_limit_usd = Decimal(str(daily_limit))
        if monthly_limit is not None:
            budget.monthly_limit_usd = Decimal(str(monthly_limit))
        if alert_threshold is not None:
            budget.alert_threshold = Decimal(str(alert_threshold))
        if enabled is not None:
            budget.enabled = enabled

        await db.commit()
        await db.refresh(budget)
        return budget

    async def get_cost_summary(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> dict:
        """Return a JSON-friendly summary of the user's spend and budget.

        All amounts are floats in USD. The ``budget`` limits are ``None`` when
        the user has no budget row. The usage percentages are ``None`` when
        there is no budget or the corresponding limit is falsy (e.g. 0), which
        also avoids dividing by zero. ``by_capability`` covers the default
        30-day window of ``get_cost_by_capability``.
        """
        daily = await self.get_daily_cost(db, user_id)
        monthly = await self.get_monthly_cost(db, user_id)
        by_capability = await self.get_cost_by_capability(db, user_id)
        budget = await self.get_user_budget(db, user_id)

        return {
            "daily_cost_usd": float(daily),
            "monthly_cost_usd": float(monthly),
            "by_capability": {k: float(v) for k, v in by_capability.items()},
            "budget": {
                "daily_limit_usd": float(budget.daily_limit_usd) if budget else None,
                "monthly_limit_usd": float(budget.monthly_limit_usd) if budget else None,
                "daily_usage_percent": float(daily / budget.daily_limit_usd * 100)
                if budget and budget.daily_limit_usd
                else None,
                "monthly_usage_percent": float(monthly / budget.monthly_limit_usd * 100)
                if budget and budget.monthly_limit_usd
                else None,
                "enabled": budget.enabled if budget else False,
            },
        }
|
|
|
|
|
|
# Global singleton: module-level CostTracker instance for importers to share.
cost_tracker = CostTracker()
|