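# NOTE: the lines below appear to be web-UI page chrome captured along with the
# source paste; they are kept as comments so the module remains valid Python.
# Files
# dreamweaver/backend/app/api/admin_providers.py
#
# 591 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters.
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.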
from datetime import datetime
from typing import Any, Literal

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, ConfigDict, Field, field_validator
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.admin_auth import admin_guard
from app.db.admin_models import Provider
from app.db.database import get_db
from app.services.adapters.registry import AdapterRegistry
from app.services.admin_evaluation_analytics import get_admin_evaluation_analytics
from app.services.admin_executor_coverage import get_admin_executor_coverage
from app.services.admin_generation_trace import get_admin_generation_job_trace
from app.services.admin_harness_readiness import get_admin_harness_readiness
from app.services.admin_provider_analytics import get_admin_provider_analytics
from app.services.cost_tracker import cost_tracker
from app.services.provider_policy import DEFAULT_PROVIDERS, list_capability_policies
from app.services.secret_service import SecretService
# Every route registered on this router runs the ``admin_guard`` dependency first
# (presumably restricting these endpoints to administrators — see app.core.admin_auth).
router = APIRouter(dependencies=[Depends(admin_guard)])
class ProviderCreate(BaseModel):
    """Request body for creating a provider (``POST /providers``).

    ``type`` is restricted by regex to one of: text, image, tts, storybook, asr.
    """

    name: str
    type: str = Field(..., pattern="^(text|image|tts|storybook|asr)$")
    adapter: str
    model: str | None = None
    api_base: str | None = None
    api_key: str | None = None  # Optional; takes precedence over config_ref.
    timeout_ms: int = 60000
    max_retries: int = 1
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_json: dict | None = None
    config_ref: str | None = None  # Name of an environment-variable key (fallback).
    updated_by: str | None = None
class ProviderUpdate(ProviderCreate):
    """Partial-update body for ``PUT /providers/{provider_id}``.

    The update handler applies ``model_dump(exclude_unset=True)``, so only the
    fields a client actually sends are written. ``name``, ``type`` and
    ``adapter`` are therefore optional here (they are required on create);
    previously they were inherited as required, which rejected partial
    payloads such as ``{"enabled": false}``.
    """

    name: str | None = None
    type: str | None = Field(default=None, pattern="^(text|image|tts|storybook|asr)$")
    adapter: str | None = None
    enabled: bool | None = None
    api_key: str | None = None
    config_json: dict | None = None

    @field_validator("name", "type", "adapter")
    @classmethod
    def _reject_explicit_null(cls, value: str | None) -> str | None:
        """Reject an explicit ``null`` for fields that are required on create.

        Omitting these fields is allowed (they are simply not updated), but a
        literal ``null`` would blank them out. That was already a validation
        error when the fields were a plain required ``str``, so keep it one.
        The validator only runs for values the client supplied, not defaults.
        """
        if value is None:
            raise ValueError("must not be null")
        return value
class ProviderResponse(BaseModel):
    """Provider response model with sensitive fields hidden.

    The stored API key is never exposed; only ``has_api_key`` reports whether one
    is configured.
    """

    id: str
    name: str
    type: str
    adapter: str
    model: str | None = None
    api_base: str | None = None
    has_api_key: bool = False  # Only flags whether an api_key is configured; plaintext is never returned.
    timeout_ms: int = 60000
    max_retries: int = 1
    weight: int = 1
    priority: int = 0
    enabled: bool = True
    config_ref: str | None = None
    # Allows building the model from attribute-bearing objects (e.g. ORM rows);
    # the handlers visible in this module construct it with explicit keywords instead.
    model_config = ConfigDict(from_attributes=True)
class ProviderAnalyticsBucket(BaseModel):
    """Per (capability, adapter) call statistics inside ``ProviderAnalyticsResponse.by_provider``."""

    capability: str
    adapter: str
    call_count: int
    success_count: int
    failure_count: int
    # Optional; presumably None when no latency samples exist — confirm in admin_provider_analytics.
    avg_latency_ms: float | None = None
    estimated_cost_usd: float
class ProviderAnalyticsUserBucket(BaseModel):
    """Per-user call statistics inside ``ProviderAnalyticsResponse.by_user``."""

    user_id: str
    call_count: int
    success_count: int
    failure_count: int
    job_count: int
    story_count: int
    estimated_cost_usd: float
class ProviderAnalyticsFailureReason(BaseModel):
    """A failure reason and its occurrence count in ``ProviderAnalyticsResponse.failure_reasons``."""

    reason: str
    count: int
class ProviderAnalyticsResponse(BaseModel):
    """Response model for ``GET /providers/analytics``."""

    scope: str
    window_days: int | None = None  # Presumably echoes the ``days`` query parameter — confirm.
    capability: str | None = None  # Presumably echoes the ``capability`` filter — confirm.
    total_calls: int
    successful_calls: int
    failed_calls: int
    avg_latency_ms: float | None = None
    estimated_cost_usd: float
    user_count: int
    job_count: int
    story_count: int
    voice_session_count: int = 0
    voice_turn_count: int = 0
    by_provider: list[ProviderAnalyticsBucket]
    by_user: list[ProviderAnalyticsUserBucket]
    failure_reasons: list[ProviderAnalyticsFailureReason]
class EvaluationAnalyticsArtifactBucket(BaseModel):
    """Evaluation count for one artifact kind (``EvaluationAnalyticsResponse.by_artifact``)."""

    artifact: str
    count: int
class EvaluationAnalyticsOutputModeBucket(BaseModel):
    """Evaluation count for one output mode (``EvaluationAnalyticsResponse.by_output_mode``)."""

    output_mode: str
    count: int
class EvaluationAnalyticsScoreBandBucket(BaseModel):
    """Evaluation count for one score band (``EvaluationAnalyticsResponse.score_bands``)."""

    band: str
    count: int
class EvaluationAnalyticsDimensionScore(BaseModel):
    """Average score and sample count for one evaluation dimension."""

    dimension: str
    average_score: float
    count: int
class EvaluationAnalyticsQualityGateIssue(BaseModel):
    """Occurrence count for one quality-gate issue code."""

    code: str
    count: int
class EvaluationAnalyticsFailureCategory(BaseModel):
    """Occurrence count for one evaluation failure category."""

    category: str
    count: int
class EvaluationAnalyticsWarning(BaseModel):
    """A warning message and how many times it occurred."""

    message: str
    count: int
class EvaluationAnalyticsResponse(BaseModel):
    """Response model for ``GET /evaluations/analytics``; also nested in the harness readiness response."""

    scope: str
    window_days: int | None = None
    artifact: str | None = None
    total_evaluations: int
    passed_evaluations: int
    blocked_evaluations: int
    pass_rate: float  # Scale (0-1 vs. percent) is set by admin_evaluation_analytics — confirm.
    average_score: float | None = None
    job_count: int
    story_count: int
    user_count: int
    by_artifact: list[EvaluationAnalyticsArtifactBucket]
    by_output_mode: list[EvaluationAnalyticsOutputModeBucket]
    score_bands: list[EvaluationAnalyticsScoreBandBucket]
    dimension_scores: list[EvaluationAnalyticsDimensionScore]
    quality_gate_issues: list[EvaluationAnalyticsQualityGateIssue]
    failure_categories: list[EvaluationAnalyticsFailureCategory]
    warnings: list[EvaluationAnalyticsWarning]
class ExecutorCoveragePlanModeBucket(BaseModel):
    """Run count for one plan mode (``ExecutorCoverageResponse.by_plan_mode``)."""

    plan_mode: str
    count: int
class ExecutorCoverageOutputModeBucket(BaseModel):
    """Run count for one output mode (``ExecutorCoverageResponse.by_output_mode``)."""

    output_mode: str
    count: int
class ExecutorCoverageTaskKeyBucket(BaseModel):
    """Count for one task key; used for both executed and ignored task lists."""

    task_key: str
    count: int
class ExecutorCoverageAssetBucket(BaseModel):
    """Count for one result asset type (``ExecutorCoverageResponse.result_assets``)."""

    asset: str
    count: int
class ExecutorCoverageResponse(BaseModel):
    """Response model for ``GET /executors/coverage``; also nested in job-trace and readiness responses."""

    scope: str
    window_days: int | None = None
    plan_mode: str | None = None
    total_runs: int
    total_planned_tasks: int
    total_executed_tasks: int
    total_ignored_tasks: int
    coverage_ratio: float  # Exact formula lives in admin_executor_coverage — not visible here.
    job_count: int
    story_count: int
    user_count: int
    by_plan_mode: list[ExecutorCoveragePlanModeBucket]
    by_output_mode: list[ExecutorCoverageOutputModeBucket]
    executed_task_keys: list[ExecutorCoverageTaskKeyBucket]
    ignored_task_keys: list[ExecutorCoverageTaskKeyBucket]
    result_assets: list[ExecutorCoverageAssetBucket]
class AdminGenerationJobEventResponse(BaseModel):
    """One event in a generation job's trace (``AdminGenerationJobTraceResponse.events``)."""

    id: int
    job_id: str
    story_id: int | None = None
    event_type: str
    status: str
    message: str | None = None
    event_metadata: dict[str, Any] = Field(default_factory=dict)  # Fresh dict per instance.
    created_at: datetime
class AdminGenerationJobTraceResponse(BaseModel):
    """Response model for ``GET /generations/jobs/{job_id}/trace``.

    Combines the job's state, its raw request/result payloads, executor coverage
    and its event list.
    """

    id: str
    user_id: str
    story_id: int | None = None
    output_mode: str
    input_type: str
    status: str
    current_step: str
    progress_percent: int
    progress_label: str
    is_terminal: bool
    can_cancel: bool = False
    can_retry: bool = False
    result_snapshot: dict[str, Any] = Field(default_factory=dict)
    error_message: str | None = None
    request_payload: dict[str, Any] = Field(default_factory=dict)
    executor_coverage: ExecutorCoverageResponse
    events: list[AdminGenerationJobEventResponse] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
class HarnessReadinessCheck(BaseModel):
    """Result of a single readiness check, tagged with one of three statuses."""

    code: str
    status: Literal["ready", "needs_attention", "blocked"]
    message: str
    details: dict[str, Any] = Field(default_factory=dict)
class HarnessReadinessGoldenReplay(BaseModel):
    """Outcome of replaying the golden test cases, including which cases failed."""

    passed: bool
    total_cases: int
    failed_case_ids: list[str]
    # Nested counters; key meanings are defined by admin_harness_readiness — not visible here.
    coverage_summary: dict[str, dict[str, int]] = Field(default_factory=dict)
class HarnessReadinessThresholds(BaseModel):
    """Minimum values the readiness checks are evaluated against."""

    min_runtime_evaluations: int
    min_executor_runs: int
    min_evaluation_pass_rate: float
    min_evaluation_average_score: float
    min_executor_coverage_ratio: float
class HarnessReadinessResponse(BaseModel):
    """Response model for ``GET /harness/readiness``.

    Bundles the overall status, the thresholds used, the individual checks, the
    golden-replay outcome, and the evaluation and executor summaries they draw on.
    """

    scope: str
    window_days: int | None = None
    status: Literal["ready", "needs_attention", "blocked"]
    thresholds: HarnessReadinessThresholds
    checks: list[HarnessReadinessCheck]
    golden_replay: HarnessReadinessGoldenReplay
    evaluation_analytics: EvaluationAnalyticsResponse
    executor_coverage: ExecutorCoverageResponse
@router.get("/providers/adapters")
async def list_available_adapters():
    """Return every adapter type available in ``AdapterRegistry`` (the defined classes)."""
    registered_adapters = AdapterRegistry.list_adapters()
    return registered_adapters
@router.get("/providers/defaults")
async def get_env_defaults():
    """Return the default provider policy defined by the current environment variables (read-only)."""
    env_defaults = DEFAULT_PROVIDERS
    return env_defaults
@router.get("/providers/capabilities")
async def list_provider_capabilities():
    """Return the provider capability tiers together with their default routing policies."""
    policies = list_capability_policies()
    return policies
@router.get("/providers/analytics", response_model=ProviderAnalyticsResponse)
async def get_provider_analytics(
    days: int | None = Query(default=None, ge=1, le=365),
    capability: Literal["text", "image", "tts", "storybook", "asr"] | None = Query(
        default=None
    ),
    db: AsyncSession = Depends(get_db),
):
    """Return a cross-user operational summary of providers for the current environment.

    Args:
        days: Optional look-back window (1-365); forwarded unchanged, including ``None``.
        capability: Optional capability filter; forwarded unchanged.
    """
    filters = {"days": days, "capability": capability}
    return await get_admin_provider_analytics(db, **filters)
@router.get("/evaluations/analytics", response_model=EvaluationAnalyticsResponse)
async def get_evaluation_analytics(
    days: int | None = Query(default=None, ge=1, le=365),
    artifact: Literal["story_text", "storybook_pages"] | None = Query(default=None),
    db: AsyncSession = Depends(get_db),
):
    """Return the internal content-evaluation summary (admin control plane only).

    Args:
        days: Optional look-back window (1-365); forwarded unchanged.
        artifact: Optional artifact filter; forwarded unchanged.
    """
    filters = {"days": days, "artifact": artifact}
    return await get_admin_evaluation_analytics(db, **filters)
@router.get("/executors/coverage", response_model=ExecutorCoverageResponse)
async def get_executor_coverage(
    days: int | None = Query(default=None, ge=1, le=365),
    plan_mode: Literal["asset_generation", "asset_retry"] | None = Query(default=None),
    db: AsyncSession = Depends(get_db),
):
    """Return internal executor execution coverage (admin control plane only).

    Args:
        days: Optional look-back window (1-365); forwarded unchanged.
        plan_mode: Optional plan-mode filter; forwarded unchanged.
    """
    filters = {"days": days, "plan_mode": plan_mode}
    return await get_admin_executor_coverage(db, **filters)
@router.get("/harness/readiness", response_model=HarnessReadinessResponse)
async def get_harness_readiness(
    days: int | None = Query(default=None, ge=1, le=365),
    db: AsyncSession = Depends(get_db),
):
    """Return the internal harness-readiness review summary (admin control plane only)."""
    readiness = await get_admin_harness_readiness(db, days=days)
    return readiness
@router.get(
    "/generations/jobs/{job_id}/trace",
    response_model=AdminGenerationJobTraceResponse,
)
async def get_generation_job_trace(
    job_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return the full internal generation trace for one job.

    Intended for troubleshooting and review from the admin control plane only.
    NOTE(review): unknown-job handling is delegated to
    ``get_admin_generation_job_trace`` — confirm it raises a 404 rather than
    returning ``None`` (which would fail response validation).
    """
    trace = await get_admin_generation_job_trace(db, job_id=job_id)
    return trace
@router.get("/providers", response_model=list[ProviderResponse])
async def list_providers(db: AsyncSession = Depends(get_db)):
    """List every configured provider with secrets redacted.

    Each row goes through ``_to_response``, the module's single ORM-to-response
    mapping, so this endpoint cannot drift from create/update in which fields it
    exposes. ``api_key`` is reported only as the ``has_api_key`` flag.
    """
    result = await db.execute(select(Provider))
    return [_to_response(provider) for provider in result.scalars().all()]
def _to_response(provider: Provider) -> ProviderResponse:
    """Map a Provider row to its public ``ProviderResponse``.

    Every public attribute is copied verbatim; the stored ``api_key`` is not
    copied, and only whether it is truthy is reported via ``has_api_key``.
    """
    copied_attributes = (
        "id",
        "name",
        "type",
        "adapter",
        "model",
        "api_base",
        "timeout_ms",
        "max_retries",
        "weight",
        "priority",
        "enabled",
        "config_ref",
    )
    response_fields = {attr: getattr(provider, attr) for attr in copied_attributes}
    response_fields["has_api_key"] = bool(provider.api_key)
    return ProviderResponse(**response_fields)
@router.post("/providers", response_model=ProviderResponse)
async def create_provider(payload: ProviderCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new provider and return its redacted representation.

    A non-empty ``api_key`` is passed through ``SecretService.encrypt`` before the
    row is built, so the plaintext is never handed to the ORM.
    """
    fields = payload.model_dump()
    plaintext_key = fields.get("api_key")
    if plaintext_key:
        fields["api_key"] = SecretService.encrypt(plaintext_key)

    provider = Provider(**fields)
    db.add(provider)
    await db.commit()
    # Reload server-generated columns (e.g. id) before serialising.
    await db.refresh(provider)
    return _to_response(provider)
@router.put("/providers/{provider_id}", response_model=ProviderResponse)
async def update_provider(
    provider_id: str, payload: ProviderUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply the fields present in ``payload`` to an existing provider.

    Only explicitly sent fields are written (``exclude_unset=True``). A non-empty
    ``api_key`` is encrypted before it is stored.

    Raises:
        HTTPException: 404 when no provider has ``provider_id``.
    """
    lookup = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = lookup.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    changes = payload.model_dump(exclude_unset=True)
    new_key = changes.get("api_key")
    if new_key:
        changes["api_key"] = SecretService.encrypt(new_key)
    # NOTE(review): an explicitly sent empty string / null api_key is written
    # as-is and clears the stored key. Responses never return the key, so a
    # client re-submitting a full edit form could wipe it unintentionally —
    # confirm this is the intended "clear key" behaviour.
    for field_name, value in changes.items():
        setattr(provider, field_name, value)

    await db.commit()
    await db.refresh(provider)
    return _to_response(provider)
@router.delete("/providers/{provider_id}")
async def delete_provider(provider_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a provider by id.

    Raises:
        HTTPException: 404 when no provider has ``provider_id``.
    """
    lookup = await db.execute(select(Provider).where(Provider.id == provider_id))
    doomed = lookup.scalar_one_or_none()
    if doomed:
        await db.delete(doomed)
        await db.commit()
        return {"message": "deleted"}
    raise HTTPException(status_code=404, detail="Provider not found")
# ==================== Secret management API ====================
class SecretCreate(BaseModel):
    """Request body for creating or updating a secret."""

    # Secret name, e.g. CQTAI_API_KEY.
    name: str = Field(..., description="密钥名称,如 CQTAI_API_KEY")
    # Plaintext secret value.
    value: str = Field(..., description="密钥明文值")
class SecretResponse(BaseModel):
    """Secret metadata response; the plaintext value is never returned."""

    name: str
    # ISO-8601 strings (built with ``datetime.isoformat()``), or None when the timestamp is unset.
    created_at: str | None = None
    updated_at: str | None = None
@router.get("/secrets", response_model=list[str])
async def list_secrets(db: AsyncSession = Depends(get_db)):
    """List the names of all stored secrets (values are never included)."""
    secret_names = await SecretService.list_secrets(db)
    return secret_names
@router.post("/secrets", response_model=SecretResponse)
async def create_or_update_secret(payload: SecretCreate, db: AsyncSession = Depends(get_db)):
    """Create or update a secret and return its metadata (never the value)."""
    secret = await SecretService.set_secret(db, payload.name, payload.value)
    created, updated = secret.created_at, secret.updated_at
    return SecretResponse(
        name=secret.name,
        created_at=None if not created else created.isoformat(),
        updated_at=None if not updated else updated.isoformat(),
    )
@router.delete("/secrets/{name}")
async def delete_secret(name: str, db: AsyncSession = Depends(get_db)):
    """Delete a secret by name.

    Raises:
        HTTPException: 404 when ``SecretService.delete_secret`` reports nothing was deleted.
    """
    if await SecretService.delete_secret(db, name):
        return {"message": "deleted"}
    raise HTTPException(status_code=404, detail="Secret not found")
@router.get("/secrets/{name}/verify")
async def verify_secret(name: str, db: AsyncSession = Depends(get_db)):
    """Check that a secret exists and can be read back (the plaintext is not returned).

    Raises:
        HTTPException: 404 when ``SecretService.get_secret`` returns ``None``.
    """
    plaintext = await SecretService.get_secret(db, name)
    if plaintext is not None:
        # NOTE(review): "length" discloses the secret's length to the caller;
        # acceptable for an admin-only route, but worth confirming.
        return {"name": name, "valid": True, "length": len(plaintext)}
    raise HTTPException(status_code=404, detail="Secret not found")
# ==================== Cost tracking API ====================
class BudgetUpdate(BaseModel):
    """Budget update request; every field is optional and forwarded to ``cost_tracker.set_user_budget``."""

    daily_limit_usd: float | None = None
    monthly_limit_usd: float | None = None
    alert_threshold: float | None = Field(default=None, ge=0, le=1)  # Validated to lie in [0, 1].
    enabled: bool | None = None
@router.get("/costs/summary/{user_id}")
async def get_user_cost_summary(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return the cost summary for a single user, as computed by ``cost_tracker``."""
    summary = await cost_tracker.get_cost_summary(db, user_id)
    return summary
async def _aggregate_costs(
    db: AsyncSession, group_column: Any, cost_column: Any, key: str
) -> list[dict[str, Any]]:
    """Sum ``cost_column`` and count rows for each distinct value of ``group_column``.

    Args:
        db: Active async session.
        group_column: Column to GROUP BY (e.g. ``CostRecord.user_id``).
        cost_column: Numeric column to SUM.
        key: Output dict key under which the group value is reported.

    Returns:
        One ``{key, "total_cost_usd", "call_count"}`` dict per group.
    """
    from sqlalchemy import func

    result = await db.execute(
        select(
            group_column,
            func.sum(cost_column).label("total_cost"),
            func.count().label("call_count"),
        ).group_by(group_column)
    )
    return [
        # SQL SUM yields NULL when every value in the group is NULL; report 0.0
        # instead of crashing on float(None).
        {key: group_value, "total_cost_usd": float(total_cost or 0), "call_count": call_count}
        for group_value, total_cost, call_count in result.all()
    ]


@router.get("/costs/all")
async def get_all_costs_summary(db: AsyncSession = Depends(get_db)):
    """Return cost totals across all users (admin).

    Returns:
        ``{"by_user": [...], "by_capability": [...]}``; each entry carries
        ``total_cost_usd`` (sum of ``CostRecord.estimated_cost``) and ``call_count``.
    """
    from app.db.admin_models import CostRecord

    users = await _aggregate_costs(
        db, CostRecord.user_id, CostRecord.estimated_cost, "user_id"
    )
    capabilities = await _aggregate_costs(
        db, CostRecord.capability, CostRecord.estimated_cost, "capability"
    )
    return {"by_user": users, "by_capability": capabilities}
@router.get("/budgets/{user_id}")
async def get_user_budget(user_id: str, db: AsyncSession = Depends(get_db)):
    """Return a user's budget configuration, or ``"budget": None`` when none is set."""
    budget = await cost_tracker.get_user_budget(db, user_id)
    if budget:
        # NOTE(review): float() assumes the limit/threshold columns are never NULL — confirm.
        budget_view = {
            "daily_limit_usd": float(budget.daily_limit_usd),
            "monthly_limit_usd": float(budget.monthly_limit_usd),
            "alert_threshold": float(budget.alert_threshold),
            "enabled": budget.enabled,
        }
    else:
        budget_view = None
    return {"user_id": user_id, "budget": budget_view}
@router.post("/budgets/{user_id}")
async def set_user_budget(
    user_id: str, payload: BudgetUpdate, db: AsyncSession = Depends(get_db)
):
    """Create or update a user's budget and return the resulting configuration."""
    requested = {
        "daily_limit": payload.daily_limit_usd,
        "monthly_limit": payload.monthly_limit_usd,
        "alert_threshold": payload.alert_threshold,
        "enabled": payload.enabled,
    }
    budget = await cost_tracker.set_user_budget(db, user_id, **requested)
    # NOTE(review): float() assumes cost_tracker fills in non-None limits even when
    # the payload omits them — confirm against cost_tracker.set_user_budget.
    budget_view = {
        "daily_limit_usd": float(budget.daily_limit_usd),
        "monthly_limit_usd": float(budget.monthly_limit_usd),
        "alert_threshold": float(budget.alert_threshold),
        "enabled": budget.enabled,
    }
    return {"user_id": user_id, "budget": budget_view}