"""Admin API routes for providers, stored secrets, and cost/budget tracking.

Every route registered on ``router`` depends on ``admin_guard``.
"""

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

# Local imports are kept in their original load order.
from app.core.admin_auth import admin_guard
from app.db.admin_models import CostRecord, Provider
from app.db.database import get_db
from app.services.cost_tracker import cost_tracker
from app.services.secret_service import SecretService
from app.services.adapters.registry import AdapterRegistry
from app.services.provider_router import DEFAULT_PROVIDERS

router = APIRouter(dependencies=[Depends(admin_guard)])


class ProviderCreate(BaseModel):
    """Request body for creating a provider."""

    name: str
    type: str = Field(..., pattern="^(text|image|tts|storybook)$")
    adapter: str
    model: str | None = None
    api_base: str | None = None
    api_key: str | None = None  # optional; takes precedence over config_ref
    timeout_ms: int = 60000
    max_retries: int = 1
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_json: dict | None = None
    config_ref: str | None = None  # name of the environment-variable key (fallback)
    updated_by: str | None = None


class ProviderUpdate(ProviderCreate):
    """Request body for updating a provider.

    Only fields the client actually sends are applied, because
    ``update_provider`` dumps this model with ``exclude_unset=True``; the
    inherited defaults are therefore never written on their own.
    """

    # NOTE(review): ``name``, ``type`` and ``adapter`` are inherited as
    # required fields, so every update request must still include them.
    # Confirm whether partial updates are meant to be supported.
    enabled: bool | None = None
    api_key: str | None = None
    config_json: dict | None = None


class ProviderResponse(BaseModel):
    """Provider response model; sensitive fields are hidden."""

    id: str
    name: str
    type: str
    adapter: str
    model: str | None = None
    api_base: str | None = None
    # Only signals whether an api_key is configured; the value is never returned.
    has_api_key: bool = False
    timeout_ms: int = 60000
    max_retries: int = 1
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_ref: str | None = None

    model_config = ConfigDict(from_attributes=True)


def _to_response(provider: Provider) -> ProviderResponse:
    """Convert a ``Provider`` row into its response model.

    The stored ``api_key`` is reduced to the ``has_api_key`` flag so the key
    itself never appears in a response.
    """
    return ProviderResponse(
        id=provider.id,
        name=provider.name,
        type=provider.type,
        adapter=provider.adapter,
        model=provider.model,
        api_base=provider.api_base,
        has_api_key=bool(provider.api_key),
        timeout_ms=provider.timeout_ms,
        max_retries=provider.max_retries,
        weight=provider.weight,
        priority=provider.priority,
        enabled=provider.enabled,
        config_ref=provider.config_ref,
    )


async def _get_provider_or_404(db: AsyncSession, provider_id: str) -> Provider:
    """Load a provider by id.

    Raises:
        HTTPException: 404 when no provider has the given id.
    """
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    return provider


@router.get("/providers/adapters")
async def list_available_adapters():
    """Return all available adapter types (the defined classes)."""
    return AdapterRegistry.list_adapters()


@router.get("/providers/defaults")
async def get_env_defaults():
    """Return the default strategy defined by environment variables (read-only)."""
    return DEFAULT_PROVIDERS


@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)) -> list[ProviderResponse]:
    """List every provider, with the api_key reduced to ``has_api_key``."""
    result = await db.execute(select(Provider))
    # Reuse the single conversion helper so key-hiding stays in one place.
    return [_to_response(provider) for provider in result.scalars().all()]


@router.post("/providers", response_model=ProviderResponse)
async def create_provider(
    payload: ProviderCreate, db: AsyncSession = Depends(get_db)
) -> ProviderResponse:
    """Create a provider from the request payload and return it."""
    data = payload.model_dump()
    # A non-empty API key is passed through SecretService.encrypt before storage.
    if data.get("api_key"):
        data["api_key"] = SecretService.encrypt(data["api_key"])
    provider = Provider(**data)
    db.add(provider)
    await db.commit()
    await db.refresh(provider)
    return _to_response(provider)


@router.put("/providers/{provider_id}", response_model=ProviderResponse)
async def update_provider(
    provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
) -> ProviderResponse:
    """Apply the explicitly-sent fields of ``payload`` to an existing provider.

    Raises:
        HTTPException: 404 when no provider has the given id.
    """
    provider = await _get_provider_or_404(db, provider_id)
    # exclude_unset: only fields present in the request body are written.
    data = payload.model_dump(exclude_unset=True)
    # A non-empty API key is encrypted; an explicitly-sent empty or null
    # api_key is written as-is, which replaces the stored value.
    if data.get("api_key"):
        data["api_key"] = SecretService.encrypt(data["api_key"])
    for field_name, value in data.items():
        setattr(provider, field_name, value)
    await db.commit()
    await db.refresh(provider)
    return _to_response(provider)


@router.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a provider by id.

    Raises:
        HTTPException: 404 when no provider has the given id.
    """
    provider = await _get_provider_or_404(db, provider_id)
    await db.delete(provider)
    await db.commit()
    return {"message": "deleted"}


# ==================== Secret management API ====================


class SecretCreate(BaseModel):
    """Secret creation request."""

    name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
    value: str = Field(..., description="密钥明文值")


class SecretResponse(BaseModel):
    """Secret response; the plaintext value is not returned."""

    name: str
    created_at: str | None = None
    updated_at: str | None = None


@router.get("/secrets", response_model=list[str])
async def list_secrets(db: AsyncSession = Depends(get_db)):
    """List all secret names (values are not returned)."""
    return await SecretService.list_secrets(db)


@router.post("/secrets", response_model=SecretResponse)
async def create_or_update_secret(
    payload: SecretCreate, db: AsyncSession = Depends(get_db)
) -> SecretResponse:
    """Create or update a secret and return its name and timestamps."""
    secret = await SecretService.set_secret(db, payload.name, payload.value)
    return SecretResponse(
        name=secret.name,
        # Timestamps are serialized as ISO-8601 strings, or None when unset.
        created_at=secret.created_at.isoformat() if secret.created_at else None,
        updated_at=secret.updated_at.isoformat() if secret.updated_at else None,
    )


@router.delete("/secrets/{name}")
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
    """Delete a secret.

    Raises:
        HTTPException: 404 when ``SecretService.delete_secret`` reports that
            nothing was deleted.
    """
    deleted = await SecretService.delete_secret(db, name)
    if not deleted:
        raise HTTPException(status_code=404, detail="Secret not found")
    return {"message": "deleted"}


@router.get("/secrets/{name}/verify")
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
    """Check that a secret can be read back, without returning its plaintext.

    Raises:
        HTTPException: 404 when ``SecretService.get_secret`` returns None.
    """
    value = await SecretService.get_secret(db, name)
    if value is None:
        raise HTTPException(status_code=404, detail="Secret not found")
    # NOTE(review): the response exposes the secret's length; confirm that
    # this is acceptable for the admin API.
    return {"name": name, "valid": True, "length": len(value)}


# ==================== Cost tracking API ====================


class BudgetUpdate(BaseModel):
    """Budget update request."""

    daily_limit_usd: float | None = None
    monthly_limit_usd: float | None = None
    alert_threshold: float | None = Field(default=None, ge=0, le=1)
    enabled: bool | None = None


async def _cost_totals_by(db: AsyncSession, column, key: str) -> list[dict]:
    """Aggregate cost records grouped by ``column``.

    Args:
        db: Session used to run the query.
        column: ``CostRecord`` column to group by.
        key: Dictionary key under which each group's value is reported.

    Returns:
        One dict per group with ``key``, ``total_cost_usd`` and ``call_count``.
    """
    result = await db.execute(
        select(
            column,
            func.sum(CostRecord.estimated_cost).label("total_cost"),
            func.count().label("call_count"),
        ).group_by(column)
    )
    return [
        {
            key: group_value,
            # SUM() is NULL when every estimated_cost in the group is NULL;
            # report 0.0 rather than failing on float(None).
            "total_cost_usd": 0.0 if total_cost is None else float(total_cost),
            "call_count": call_count,
        }
        for group_value, total_cost, call_count in result.all()
    ]


def _budget_payload(user_id: str, budget) -> dict:
    """Serialize a budget record into the response shape of the budget routes."""
    return {
        "user_id": user_id,
        "budget": {
            "daily_limit_usd": float(budget.daily_limit_usd),
            "monthly_limit_usd": float(budget.monthly_limit_usd),
            "alert_threshold": float(budget.alert_threshold),
            "enabled": budget.enabled,
        },
    }


@router.get("/costs/summary/{user_id}")
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return the cost summary for one user."""
    return await cost_tracker.get_cost_summary(db, user_id)


@router.get("/costs/all")
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
    """Return cost totals for all users (admin), by user and by capability."""
    users = await _cost_totals_by(db, CostRecord.user_id, "user_id")
    capabilities = await _cost_totals_by(db, CostRecord.capability, "capability")
    return {"by_user": users, "by_capability": capabilities}


@router.get("/budgets/{user_id}")
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return a user's budget configuration, or ``budget: None`` if unset."""
    budget = await cost_tracker.get_user_budget(db, user_id)
    if not budget:
        return {"user_id": user_id, "budget": None}
    return _budget_payload(user_id, budget)


@router.post("/budgets/{user_id}")
async def set_user_budget(
    user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
):
    """Set a user's budget and return the resulting configuration."""
    budget = await cost_tracker.set_user_budget(
        db,
        user_id,
        daily_limit=payload.daily_limit_usd,
        monthly_limit=payload.monthly_limit_usd,
        alert_threshold=payload.alert_threshold,
        enabled=payload.enabled,
    )
    return _budget_payload(user_id, budget)