929 lines
28 KiB
Python
929 lines
28 KiB
Python
"""Lightweight generation job/event tracking."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any
|
||
|
||
from fastapi import HTTPException
|
||
from sqlalchemy import desc, select, update
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.db.models import GenerationJob, GenerationJobEvent, Story
|
||
|
||
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
def _is_terminal_status(status: str) -> bool:
    """Return True when ``status`` is one of the finished job states.

    The finished states are ``completed``, ``degraded_completed``,
    ``failed`` and ``canceled``; any other value yields False.
    """
    finished_states = frozenset(
        {"completed", "degraded_completed", "failed", "canceled"}
    )
    return status in finished_states
|
||
|
||
|
||
def _job_supports_queue_control(job: GenerationJob) -> bool:
    """Return True when the job's output mode allows cancel/retry handling.

    Only the ``story``, ``storybook`` and ``asset_generation`` output modes
    qualify.
    """
    controllable_modes = frozenset({"story", "storybook", "asset_generation"})
    return job.output_mode in controllable_modes
|
||
|
||
|
||
def generation_job_can_cancel(job: GenerationJob) -> bool:
    """Return True when a cancel request may be issued for ``job``.

    All three conditions must hold: the output mode supports queue control,
    the job status is ``running``, and the current step is not already
    ``cancel_requested``.
    """
    if not _job_supports_queue_control(job):
        return False
    if job.status != "running":
        return False
    return job.current_step != "cancel_requested"
|
||
|
||
|
||
def generation_job_can_retry(job: GenerationJob) -> bool:
    """Return True when ``job`` may be retried.

    The output mode must support queue control and the job status must be
    ``failed`` or ``canceled``.
    """
    if not _job_supports_queue_control(job):
        return False
    return job.status in {"failed", "canceled"}
|
||
|
||
|
||
def _story_snapshot(story: Story | None) -> dict[str, Any]:
    """Copy the status-related fields of ``story`` into a plain dict.

    Returns an empty dict when no story is given. Otherwise the result holds
    the story id under ``story_id`` followed by the mode, the per-asset
    status fields, the retryable assets and the last error.
    """
    if story is None:
        return {}

    copied_fields = (
        "mode",
        "generation_status",
        "text_status",
        "image_status",
        "audio_status",
        "retryable_assets",
        "last_error",
    )
    snapshot: dict[str, Any] = {"story_id": story.id}
    for field_name in copied_fields:
        snapshot[field_name] = getattr(story, field_name)
    return snapshot
|
||
|
||
|
||
def _job_status_from_story(story: Story) -> str:
    """Derive the final job status from the story's generation status.

    ``failed`` and ``degraded_completed`` carry over unchanged; every other
    generation status maps to ``completed``.
    """
    story_outcome = story.generation_status
    for carried_over in ("failed", "degraded_completed"):
        if story_outcome == carried_over:
            return carried_over
    return "completed"
|
||
|
||
|
||
def _job_progress(job: GenerationJob) -> dict[str, Any]:
    """Build a compact progress summary for polling clients.

    Returns a dict with ``progress_percent``, ``progress_label`` and
    ``is_terminal``. A job whose status is already final always reports 100
    percent. Otherwise the values are looked up from the job's current step,
    with a generic 10 percent entry for steps that are not in the table.
    """
    # Finished statuses are answered from the status alone.
    finished_labels = {
        "failed": "生成失败",
        "canceled": "已取消",
        "completed": "已完成",
        "degraded_completed": "降级完成",
    }
    finished_label = finished_labels.get(job.status)
    if finished_label is not None:
        return {
            "progress_percent": 100,
            "progress_label": finished_label,
            "is_terminal": True,
        }

    # Step name -> (percent, label) for jobs that are still in flight.
    step_table: dict[str, tuple[int, str]] = {
        "request_accepted": (5, "已接收请求"),
        "retry_queued": (8, "重新排队中"),
        "worker_started": (12, "后台任务已开始"),
        "cancel_requested": (15, "已请求取消"),
        "context_prepared": (20, "上下文已准备"),
        "narrative_generated": (45, "正文已生成"),
        "story_saved": (60, "主记录已保存"),
        "provider_call_started": (65, "Provider 调用中"),
        "provider_call_succeeded": (72, "Provider 调用成功"),
        "provider_call_failed": (72, "Provider 调用失败,尝试恢复"),
        "cover_image_started": (75, "封面生成中"),
        "cover_image_succeeded": (88, "封面已生成"),
        "cover_image_failed": (88, "封面生成失败"),
        "storybook_images_started": (75, "绘本插图生成中"),
        "storybook_cover_image_succeeded": (82, "绘本封面已生成"),
        "storybook_cover_image_failed": (82, "绘本封面生成失败"),
        "storybook_page_image_succeeded": (86, "分页插图已生成"),
        "storybook_page_image_failed": (86, "分页插图生成失败"),
        "storybook_images_completed": (92, "绘本插图已完成"),
        "audio_started": (75, "音频生成中"),
        "audio_cache_hit": (88, "音频缓存已复用"),
        "audio_succeeded": (88, "音频已生成"),
        "audio_failed": (88, "音频生成失败"),
        "asset_retry_started": (25, "资源重试中"),
        "postprocessing_queued": (90, "后处理已排队"),
        "asset_generation_completed": (100, "资源已完成"),
        "asset_retry_completed": (100, "资源重试完成"),
        "generation_canceled": (100, "任务已取消"),
        "generation_completed": (100, "生成完成"),
        "generation_stale_failed": (100, "任务超时已收敛"),
    }
    fallback_entry = (10, "生成处理中")
    step_percent, step_label = step_table.get(job.current_step, fallback_entry)
    return {
        "progress_percent": step_percent,
        "progress_label": step_label,
        # A step mapped to 100 percent counts as terminal even while the
        # status has not been switched to a finished value yet.
        "is_terminal": step_percent >= 100,
    }
|
||
|
||
|
||
def _normalize_datetime(value: datetime) -> datetime:
    """Return ``value`` as a timezone-aware UTC datetime.

    A naive value is tagged as UTC without shifting its clock time; an aware
    value is converted to UTC.
    """
    has_no_zone = value.tzinfo is None
    return (
        value.replace(tzinfo=timezone.utc)
        if has_no_zone
        else value.astimezone(timezone.utc)
    )
|
||
|
||
|
||
def _is_stale_job(job: GenerationJob, *, stale_after_minutes: int) -> bool:
    """Return True when a running job has not been updated recently enough.

    A job is stale when its status is ``running`` and its ``updated_at``
    (normalized to UTC) is at or before "now minus ``stale_after_minutes``".
    """
    stale_before = datetime.now(timezone.utc) - timedelta(minutes=stale_after_minutes)
    if job.status != "running":
        return False
    last_update = _normalize_datetime(job.updated_at)
    return last_update <= stale_before
|
||
|
||
|
||
def _failure_label(job: GenerationJob) -> str:
    """Pick the display label describing why ``job`` did not succeed.

    Precedence: a canceled status first, then the stale-timeout step, then a
    label specific to the asset output modes, and finally a generic failure
    label.
    """
    if job.status == "canceled":
        return "任务已取消"
    if job.current_step == "generation_stale_failed":
        return "任务超时"

    labels_by_mode = {
        "asset_retry": "资源重试失败",
        "asset_generation": "资源生成失败",
    }
    return labels_by_mode.get(job.output_mode, "生成失败")
|
||
|
||
|
||
async def create_generation_job(
    db: AsyncSession,
    *,
    user_id: str,
    output_mode: str,
    input_type: str,
    request_payload: dict[str, Any],
    story_id: int | None = None,
) -> GenerationJob:
    """Persist a new running job together with its ``request_accepted`` event.

    Args:
        db: Session used for the insert, the event and the commit.
        user_id: Owner of the job.
        output_mode: Output mode stored on the job and echoed in the event metadata.
        input_type: Input type stored on the job and echoed in the event metadata.
        request_payload: Request data stored on the job.
        story_id: Optional story the job is attached to.

    Returns:
        The job, refreshed from the database after the commit.
    """
    new_job = GenerationJob(
        user_id=user_id,
        story_id=story_id,
        output_mode=output_mode,
        input_type=input_type,
        status="running",
        current_step="request_accepted",
        request_payload=request_payload,
        result_snapshot={},
    )
    db.add(new_job)
    # Flush before recording the event, which reads the job's id.
    await db.flush()

    acceptance_details = {"output_mode": output_mode, "input_type": input_type}
    await record_generation_event(
        db,
        job=new_job,
        story_id=story_id,
        event_type="request_accepted",
        status="succeeded",
        message="Generation request accepted.",
        metadata=acceptance_details,
        commit=False,
    )

    # Job and first event are committed together.
    await db.commit()
    await db.refresh(new_job)
    return new_job
|
||
|
||
|
||
async def record_generation_event(
    db: AsyncSession,
    *,
    job: GenerationJob,
    event_type: str,
    status: str,
    story_id: int | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
    commit: bool = True,
) -> GenerationJobEvent:
    """Append one event to ``job`` and move the job's current step to it.

    Args:
        db: Session the event is added to.
        job: Job the event belongs to; its ``current_step`` is set to
            ``event_type`` and, when ``story_id`` is given, its ``story_id``
            is overwritten as well.
        event_type: Event name, also used as the job's new current step.
        status: Status stored on the event.
        story_id: Story to link; falls back to the job's story when omitted.
        message: Optional human-readable message.
        metadata: Optional event metadata; an empty dict is stored when missing.
        commit: Commit the session after adding the event when True.

    Returns:
        The newly created event object.
    """
    job.current_step = event_type
    story_given = story_id is not None
    if story_given:
        job.story_id = story_id

    event_fields = {
        "job_id": job.id,
        "story_id": story_id if story_given else job.story_id,
        "event_type": event_type,
        "status": status,
        "message": message,
        "event_metadata": metadata or {},
    }
    new_event = GenerationJobEvent(**event_fields)
    db.add(new_event)

    if commit:
        await db.commit()
    return new_event
|
||
|
||
|
||
async def claim_generation_job_for_worker(
    db: AsyncSession,
    *,
    job_id: str,
) -> GenerationJob | None:
    """Claim a queued job for a worker, succeeding at most once per queueing.

    The claim is a conditional UPDATE: it only matches a running job whose
    current step is still ``request_accepted`` or ``retry_queued`` and moves
    it to ``worker_started``.

    Returns:
        The claimed job after a ``worker_started`` event has been recorded,
        or None when the UPDATE matched no row or the job cannot be loaded
        afterwards.
    """
    claimable_steps = ["request_accepted", "retry_queued"]
    claim_statement = (
        update(GenerationJob)
        .where(
            GenerationJob.id == job_id,
            GenerationJob.status == "running",
            GenerationJob.current_step.in_(claimable_steps),
        )
        .values(current_step="worker_started")
    )
    claim_outcome = await db.execute(claim_statement)
    await db.commit()

    # No affected row: the job was not in a claimable state.
    if not claim_outcome.rowcount:
        return None

    lookup = await db.execute(select(GenerationJob).where(GenerationJob.id == job_id))
    claimed_job = lookup.scalar_one_or_none()
    if claimed_job is None:
        return None

    await record_generation_event(
        db,
        job=claimed_job,
        event_type="worker_started",
        status="running",
        message="Generation worker started processing the accepted request.",
    )
    return claimed_job
|
||
|
||
|
||
async def finish_generation_job(
    db: AsyncSession,
    *,
    job: GenerationJob,
    story: Story | None,
    status: str | None = None,
    current_step: str,
    error_message: str | None = None,
    message: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> GenerationJob:
    """Close out a job: set its final fields and append the closing event.

    Args:
        db: Session used for the event and the commit.
        job: Job to finalize.
        story: Story the job produced, if any; supplies the story id, the
            derived status and the result snapshot.
        status: Explicit final status. When empty, the status is derived
            from ``story``, or set to ``failed`` if there is no story.
        current_step: Final step name, also used as the closing event type.
        error_message: Stored on the job as-is (None clears it).
        message: Optional message for the closing event.
        metadata: Extra event metadata; ``result_snapshot`` is always added.

    Returns:
        The job, refreshed from the database after the commit.
    """
    story_present = story is not None
    job.story_id = story.id if story_present else job.story_id

    if status:
        resolved_status = status
    elif story_present:
        resolved_status = _job_status_from_story(story)
    else:
        resolved_status = "failed"
    job.status = resolved_status

    job.current_step = current_step
    job.error_message = error_message
    job.result_snapshot = _story_snapshot(story)

    # Caller metadata first; the snapshot entry wins on a key clash.
    event_details = dict(metadata or {})
    event_details["result_snapshot"] = job.result_snapshot

    await record_generation_event(
        db,
        job=job,
        story_id=job.story_id,
        event_type=current_step,
        status=job.status,
        message=message,
        metadata=event_details,
        commit=False,
    )
    await db.commit()
    await db.refresh(job)
    return job
|
||
|
||
|
||
def generation_event_to_response(event: GenerationJobEvent) -> dict[str, Any]:
    """Serialize one generation event into the dict shape used by the API.

    ``event_metadata`` is replaced by an empty dict when it is missing or
    empty; every other field is copied as-is.
    """
    copied_fields = ("id", "job_id", "story_id", "event_type", "status", "message")
    response: dict[str, Any] = {
        field_name: getattr(event, field_name) for field_name in copied_fields
    }
    response["event_metadata"] = event.event_metadata or {}
    response["created_at"] = event.created_at
    return response
|
||
|
||
|
||
def generation_job_to_summary(job: GenerationJob) -> dict[str, Any]:
    """Serialize a job into the summary dict used by the API.

    The summary holds the job's identifying fields, the progress entries
    from ``_job_progress``, the ``can_cancel`` / ``can_retry`` flags, the
    result snapshot (empty dict when unset), the error message and both
    timestamps.
    """
    progress = _job_progress(job)

    summary: dict[str, Any] = {}
    for field_name in (
        "id",
        "story_id",
        "output_mode",
        "input_type",
        "status",
        "current_step",
    ):
        summary[field_name] = getattr(job, field_name)

    summary.update(progress)
    summary["can_cancel"] = generation_job_can_cancel(job)
    summary["can_retry"] = generation_job_can_retry(job)
    summary["result_snapshot"] = job.result_snapshot or {}
    summary["error_message"] = job.error_message
    summary["created_at"] = job.created_at
    summary["updated_at"] = job.updated_at
    return summary
|
||
|
||
|
||
async def get_generation_job_for_user(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> GenerationJob:
    """Fetch a job by id, restricted to jobs owned by ``user_id``.

    Raises:
        HTTPException: 404 when no job matches both the id and the owner.
    """
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    found_job = (await db.execute(owned_job_query)).scalar_one_or_none()
    if found_job is not None:
        return found_job
    raise HTTPException(status_code=404, detail="Generation job not found")
|
||
|
||
|
||
async def request_generation_job_cancel(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Handle a user's cancel request for one of their generation jobs.

    Outcomes, checked in this order:
      * output mode without queue control -> 409
      * already canceled -> current summary, nothing changes
      * any other finished status -> 409
      * cancel already requested -> current summary, nothing changes
      * still at ``request_accepted`` / ``retry_queued`` -> the job is
        finished as ``canceled`` right away
      * otherwise -> a ``cancel_requested`` event is recorded and the job
        stays ``running``

    Returns:
        The job summary after the request has been handled.

    Raises:
        HTTPException: 404 when the job is not found for this user, 409 when
            the job cannot be canceled.
    """
    job = await get_generation_job_for_user(db, job_id=job_id, user_id=user_id)

    if not _job_supports_queue_control(job):
        raise HTTPException(status_code=409, detail="当前任务不支持取消")

    # Repeating a cancel on an already-canceled job is a no-op.
    if job.status == "canceled":
        return generation_job_to_summary(job)

    if _is_terminal_status(job.status):
        raise HTTPException(status_code=409, detail="当前任务已终止,无法取消")

    # A pending cancel request is not recorded twice.
    if job.current_step == "cancel_requested":
        return generation_job_to_summary(job)

    still_queued = job.current_step in {"request_accepted", "retry_queued"}
    if not still_queued:
        step_before_request = job.current_step
        job.error_message = "Cancellation requested by user."
        await record_generation_event(
            db,
            job=job,
            story_id=job.story_id,
            event_type="cancel_requested",
            status="running",
            message="Cancellation requested; worker will stop at the next safe checkpoint.",
            metadata={"requested_from_step": step_before_request},
            commit=False,
        )
        await db.commit()
        await db.refresh(job)
        return generation_job_to_summary(job)

    # The job is still at a queued step, so it can be closed immediately.
    linked_story = None
    if job.story_id is not None:
        story_lookup = await db.execute(
            select(Story).where(
                Story.id == job.story_id,
                Story.user_id == job.user_id,
            )
        )
        linked_story = story_lookup.scalar_one_or_none()

    await finish_generation_job(
        db,
        job=job,
        story=linked_story,
        status="canceled",
        current_step="generation_canceled",
        error_message="Generation canceled by user before worker execution started.",
        message="Generation job was canceled before worker execution started.",
    )
    return generation_job_to_summary(job)
|
||
|
||
|
||
async def get_generation_job_detail(
    db: AsyncSession,
    *,
    job_id: str,
    user_id: str,
) -> dict[str, Any]:
    """Return the summary of a user-owned job plus its payload and events.

    The result is the job summary extended with ``request_payload`` (empty
    dict when unset) and ``events``, the job's events ordered by event id.

    Raises:
        HTTPException: 404 when no job matches both the id and the owner.
    """
    owned_job_query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.user_id == user_id,
    )
    job = (await db.execute(owned_job_query)).scalar_one_or_none()
    if job is None:
        raise HTTPException(status_code=404, detail="Generation job not found")

    events_query = (
        select(GenerationJobEvent)
        .where(GenerationJobEvent.job_id == job.id)
        .order_by(GenerationJobEvent.id)
    )
    job_events = (await db.execute(events_query)).scalars().all()

    detail = dict(generation_job_to_summary(job))
    detail["request_payload"] = job.request_payload or {}
    detail["events"] = [generation_event_to_response(item) for item in job_events]
    return detail
|
||
|
||
|
||
async def list_story_generation_jobs(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> list[dict[str, Any]]:
    """List the summaries of all jobs for one story owned by ``user_id``.

    Jobs are ordered newest first by creation time, with the job id
    (descending) as tie-breaker.
    """
    newest_first_query = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
        )
        .order_by(desc(GenerationJob.created_at), desc(GenerationJob.id))
    )
    story_jobs = (await db.execute(newest_first_query)).scalars().all()

    summaries = []
    for story_job in story_jobs:
        summaries.append(generation_job_to_summary(story_job))
    return summaries
|
||
|
||
|
||
async def get_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> GenerationJob | None:
    """Return the most recently updated running job of a story, or None.

    Only jobs matching the story, the owner and the ``running`` status are
    considered; ties on ``updated_at`` are broken by the higher job id.
    """
    latest_running_query = (
        select(GenerationJob)
        .where(
            GenerationJob.story_id == story_id,
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
        .limit(1)
    )
    lookup = await db.execute(latest_running_query)
    return lookup.scalar_one_or_none()
|
||
|
||
|
||
async def ensure_no_active_story_generation_job(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
) -> None:
    """Reject new work for a story that already has a running job.

    Raises:
        HTTPException: 409 when a running job exists; the detail text
            includes that job's current progress label.
    """
    running_job = await get_active_story_generation_job(
        db, story_id=story_id, user_id=user_id
    )
    if running_job is not None:
        progress = _job_progress(running_job)
        conflict_detail = (
            f"当前故事已有运行中的任务({progress['progress_label']}),"
            "请等待当前任务完成后再试。"
        )
        raise HTTPException(status_code=409, detail=conflict_detail)
|
||
|
||
|
||
def _as_float(value: Any) -> float | None:
    """Return ``value`` as a float when it is a usable number, else None.

    Accepted: ``int`` and ``float`` values that are finite.
    Rejected (None is returned):
      * ``bool`` - it is a subclass of ``int``, but a True/False flag is not
        a latency or cost measurement;
      * NaN and +/-infinity - one such value would turn every sum and
        average computed from the results into NaN/inf;
      * every non-numeric type, including numeric strings.
    """
    import math

    if isinstance(value, bool):
        return None
    # Tuple form instead of ``int | float`` so the check does not depend on
    # runtime support for union types in isinstance().
    if not isinstance(value, (int, float)):
        return None

    number = float(value)
    return number if math.isfinite(number) else None
|
||
|
||
|
||
def _aggregate_provider_events(
    events: list[GenerationJobEvent],
    *,
    capability: str | None = None,
) -> dict[str, Any]:
    """Roll provider call events up into totals and per-provider statistics.

    Args:
        events: Provider call events. An event whose type is not
            ``provider_call_succeeded`` is counted as a failure.
        capability: When given, only events whose metadata capability equals
            this value are counted. A missing capability reads as ``unknown``.

    Returns:
        A dict with overall call counts, average latency (None without any
        latency sample), summed cost of successful calls, ``by_provider``
        rows sorted by capability then adapter, and ``failure_reasons``
        sorted by descending count then reason text.
    """
    buckets: dict[tuple[str, str], dict[str, Any]] = {}
    reason_counts: dict[str, int] = {}
    latency_sum = 0.0
    latency_samples = 0
    cost_sum = 0.0
    succeeded = 0
    failed = 0

    for event in events:
        details = event.event_metadata or {}
        event_capability = str(details.get("capability") or "unknown")
        if capability is not None and event_capability != capability:
            continue

        adapter_name = str(details.get("adapter") or "unknown")
        bucket_key = (event_capability, adapter_name)
        if bucket_key not in buckets:
            buckets[bucket_key] = {
                "capability": event_capability,
                "adapter": adapter_name,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "latency_total": 0.0,
                "latency_count": 0,
                "estimated_cost_usd": 0.0,
            }
        bucket = buckets[bucket_key]
        bucket["call_count"] += 1

        latency = _as_float(details.get("latency_ms"))
        if latency is not None:
            bucket["latency_total"] += latency
            bucket["latency_count"] += 1
            latency_sum += latency
            latency_samples += 1

        if event.event_type != "provider_call_succeeded":
            bucket["failure_count"] += 1
            failed += 1
            reason = str(details.get("error") or "unknown_error")
            reason_counts[reason] = reason_counts.get(reason, 0) + 1
            continue

        # Cost is only accumulated for successful calls.
        bucket["success_count"] += 1
        succeeded += 1
        call_cost = _as_float(details.get("estimated_cost_usd")) or 0.0
        bucket["estimated_cost_usd"] += call_cost
        cost_sum += call_cost

    provider_rows = []
    for bucket in buckets.values():
        samples = bucket["latency_count"]
        if samples:
            average_latency = round(bucket["latency_total"] / samples, 2)
        else:
            average_latency = None
        # The internal latency accumulators are left out of the output row.
        provider_rows.append(
            {
                "capability": bucket["capability"],
                "adapter": bucket["adapter"],
                "call_count": bucket["call_count"],
                "success_count": bucket["success_count"],
                "failure_count": bucket["failure_count"],
                "estimated_cost_usd": round(bucket["estimated_cost_usd"], 6),
                "avg_latency_ms": average_latency,
            }
        )
    provider_rows = sorted(
        provider_rows,
        key=lambda row: (str(row["capability"]), str(row["adapter"])),
    )

    ranked_reasons = sorted(
        reason_counts.items(),
        key=lambda pair: (-pair[1], pair[0]),
    )

    if latency_samples:
        overall_latency = round(latency_sum / latency_samples, 2)
    else:
        overall_latency = None

    return {
        "total_calls": succeeded + failed,
        "successful_calls": succeeded,
        "failed_calls": failed,
        "avg_latency_ms": overall_latency,
        "estimated_cost_usd": round(cost_sum, 6),
        "by_provider": provider_rows,
        "failure_reasons": [
            {"reason": reason_text, "count": occurrences}
            for reason_text, occurrences in ranked_reasons
        ],
    }
|
||
|
||
|
||
def _event_matches_capability(
    event: GenerationJobEvent,
    capability: str | None = None,
) -> bool:
    """Return True when ``event`` passes the optional capability filter.

    Without a filter every event matches. Otherwise the capability from the
    event metadata is compared, with a missing or empty value read as
    ``unknown``.
    """
    details = event.event_metadata or {}
    tagged_capability = str(details.get("capability") or "unknown")
    if capability is None:
        return True
    return tagged_capability == capability
|
||
|
||
|
||
def _provider_events_query(
    *,
    user_id: str | None = None,
    story_id: int | None = None,
    days: int | None = None,
):
    """Build the SELECT for provider call events joined to their jobs.

    Each row holds the event, the owning job's user id and the job's story
    id. Only ``provider_call_succeeded`` / ``provider_call_failed`` events
    are selected, ordered by event id.

    Args:
        user_id: Restrict to jobs owned by this user when given.
        story_id: Restrict to jobs of this story when given.
        days: Restrict to events created within the last ``days`` days
            (measured from the current UTC time) when given.
    """
    tracked_event_types = ["provider_call_succeeded", "provider_call_failed"]
    conditions = [GenerationJobEvent.event_type.in_(tracked_event_types)]

    if user_id is not None:
        conditions.append(GenerationJob.user_id == user_id)
    if story_id is not None:
        conditions.append(GenerationJob.story_id == story_id)
    if days is not None:
        window_start = datetime.now(timezone.utc) - timedelta(days=days)
        conditions.append(GenerationJobEvent.created_at >= window_start)

    return (
        select(
            GenerationJobEvent,
            GenerationJob.user_id,
            GenerationJob.story_id,
        )
        .join(GenerationJob, GenerationJobEvent.job_id == GenerationJob.id)
        .where(*conditions)
        .order_by(GenerationJobEvent.id)
    )
|
||
|
||
|
||
async def get_story_provider_stats(
    db: AsyncSession,
    *,
    story_id: int,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call statistics for one story owned by ``user_id``.

    Args:
        db: Session used to run the query.
        story_id: Story whose jobs are inspected.
        user_id: Owner the jobs must belong to.
        days: Optional look-back window in days.
        capability: Optional capability filter applied during aggregation.

    Returns:
        The aggregated statistics preceded by ``story_id``, ``window_days``
        and ``capability``.
    """
    statement = _provider_events_query(
        user_id=user_id,
        story_id=story_id,
        days=days,
    )
    # scalars() yields the first selected column of each row, i.e. the event.
    story_events = (await db.execute(statement)).scalars().all()

    stats: dict[str, Any] = {
        "story_id": story_id,
        "window_days": days,
        "capability": capability,
    }
    stats.update(_aggregate_provider_events(story_events, capability=capability))
    return stats
|
||
|
||
|
||
async def get_user_provider_analytics(
    db: AsyncSession,
    *,
    user_id: str,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call statistics over all jobs of one user.

    In addition to the aggregated statistics, the result reports how many
    distinct jobs and stories the capability-matching events touch. Stories
    are counted from the story id stored on each event.
    """
    statement = _provider_events_query(user_id=user_id, days=days)
    # scalars() yields the first selected column of each row, i.e. the event.
    user_events = (await db.execute(statement)).scalars().all()

    matched_job_ids = set()
    matched_story_ids = set()
    for event in user_events:
        if not _event_matches_capability(event, capability):
            continue
        matched_job_ids.add(event.job_id)
        if event.story_id is not None:
            matched_story_ids.add(event.story_id)

    analytics: dict[str, Any] = {
        "window_days": days,
        "capability": capability,
    }
    analytics.update(_aggregate_provider_events(user_events, capability=capability))
    analytics["job_count"] = len(matched_job_ids)
    analytics["story_count"] = len(matched_story_ids)
    return analytics
|
||
|
||
|
||
async def get_admin_provider_analytics(
    db: AsyncSession,
    *,
    days: int | None = None,
    capability: str | None = None,
) -> dict[str, Any]:
    """Aggregate provider call statistics across all users.

    Args:
        db: Session used to run the query.
        days: Optional look-back window in days.
        capability: Optional capability filter.

    Returns:
        The aggregated statistics plus distinct user/job/story counts and a
        ``by_user`` list sorted by call count (descending), then cost
        (descending), then user id. Story ids are taken from the joined job
        row rather than from the event.
    """
    rows = (await db.execute(_provider_events_query(days=days))).all()
    all_events = [row_event for row_event, _, _ in rows]

    seen_users = set()
    seen_jobs = set()
    seen_stories = set()
    per_user: dict[str, dict[str, Any]] = {}

    for event, owner_id, job_story_id in rows:
        if not _event_matches_capability(event, capability):
            continue

        seen_users.add(owner_id)
        seen_jobs.add(event.job_id)
        if job_story_id is not None:
            seen_stories.add(job_story_id)

        user_stats = per_user.get(owner_id)
        if user_stats is None:
            user_stats = {
                "user_id": owner_id,
                "call_count": 0,
                "success_count": 0,
                "failure_count": 0,
                "estimated_cost_usd": 0.0,
                "job_ids": set(),
                "story_ids": set(),
            }
            per_user[owner_id] = user_stats

        user_stats["call_count"] += 1
        user_stats["job_ids"].add(event.job_id)
        if job_story_id is not None:
            user_stats["story_ids"].add(job_story_id)

        if event.event_type != "provider_call_succeeded":
            user_stats["failure_count"] += 1
            continue

        # Cost is only accumulated for successful calls.
        user_stats["success_count"] += 1
        call_cost = _as_float((event.event_metadata or {}).get("estimated_cost_usd"))
        user_stats["estimated_cost_usd"] += call_cost or 0.0

    user_rows = []
    for owner_id, user_stats in per_user.items():
        user_rows.append(
            {
                "user_id": owner_id,
                "call_count": user_stats["call_count"],
                "success_count": user_stats["success_count"],
                "failure_count": user_stats["failure_count"],
                "job_count": len(user_stats["job_ids"]),
                "story_count": len(user_stats["story_ids"]),
                "estimated_cost_usd": round(user_stats["estimated_cost_usd"], 6),
            }
        )
    user_rows = sorted(
        user_rows,
        key=lambda row: (
            -int(row["call_count"]),
            -float(row["estimated_cost_usd"]),
            str(row["user_id"]),
        ),
    )

    analytics: dict[str, Any] = {
        "scope": "current_environment",
        "window_days": days,
        "capability": capability,
    }
    # The aggregation applies the capability filter itself, so it receives
    # the unfiltered event list.
    analytics.update(_aggregate_provider_events(all_events, capability=capability))
    analytics["user_count"] = len(seen_users)
    analytics["job_count"] = len(seen_jobs)
    analytics["story_count"] = len(seen_stories)
    analytics["by_user"] = user_rows
    return analytics
|
||
|
||
|
||
async def get_user_generation_ops_summary(
    db: AsyncSession,
    *,
    user_id: str,
    hours: int = 24,
    recent_failure_limit: int = 5,
) -> dict[str, Any]:
    """Summarize the recent health of one user's generation jobs.

    Args:
        db: Session used to run the queries.
        user_id: Owner of the jobs to inspect.
        hours: Look-back window for the failed/degraded/asset counters,
            based on each job's ``updated_at``.
        recent_failure_limit: Maximum number of failed jobs listed in detail.

    Returns:
        A dict with the window and stale threshold used, the number of
        running jobs and how many of them are stale, the failed / degraded /
        asset job counts inside the window, and up to
        ``recent_failure_limit`` most recently updated failures.
    """
    stale_after_minutes = settings.generation_job_stale_minutes
    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

    # Running jobs are counted regardless of the look-back window.
    running_query = (
        select(GenerationJob)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.status == "running",
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    running_jobs = (await db.execute(running_query)).scalars().all()

    # Outer join so jobs without a story still appear (title is then None).
    recent_query = (
        select(GenerationJob, Story.title)
        .outerjoin(Story, Story.id == GenerationJob.story_id)
        .where(
            GenerationJob.user_id == user_id,
            GenerationJob.updated_at >= window_start,
        )
        .order_by(desc(GenerationJob.updated_at), desc(GenerationJob.id))
    )
    recent_rows = (await db.execute(recent_query)).all()

    failure_details: list[dict[str, Any]] = []
    failed_total = 0
    degraded_total = 0
    asset_job_total = 0

    for recent_job, story_title in recent_rows:
        if recent_job.status == "failed":
            failed_total += 1
            # Every failure is counted, but only the first few are listed.
            if len(failure_details) < recent_failure_limit:
                failure_details.append(
                    {
                        "job_id": recent_job.id,
                        "story_id": recent_job.story_id,
                        "story_title": story_title,
                        "output_mode": recent_job.output_mode,
                        "current_step": recent_job.current_step,
                        "error_message": recent_job.error_message,
                        "failure_label": _failure_label(recent_job),
                        "updated_at": recent_job.updated_at,
                    }
                )
        elif recent_job.status == "degraded_completed":
            degraded_total += 1

        if recent_job.output_mode in {"asset_retry", "asset_generation"}:
            asset_job_total += 1

    stale_running = len(
        [
            running_job
            for running_job in running_jobs
            if _is_stale_job(running_job, stale_after_minutes=stale_after_minutes)
        ]
    )

    return {
        "window_hours": hours,
        "stale_threshold_minutes": stale_after_minutes,
        "active_jobs": len(running_jobs),
        "stale_running_jobs": stale_running,
        "failed_jobs": failed_total,
        "degraded_jobs": degraded_total,
        "asset_retry_jobs": asset_job_total,
        "recent_failures": failure_details,
    }
|
||
|
||
|
||
async def mark_stale_generation_jobs(
    db: AsyncSession,
    *,
    stale_after_minutes: int | None = None,
) -> dict[str, int]:
    """Fail running jobs that have not been updated within the stale threshold.

    Every running job is loaded (oldest update first). Each one that
    ``_is_stale_job`` reports as stale is finished with status ``failed``
    and step ``generation_stale_failed``, and a warning is logged.

    Args:
        db: Session used for the queries and for finishing each job.
        stale_after_minutes: Threshold in minutes. The configured
            ``settings.generation_job_stale_minutes`` is used only when this
            is None, so an explicit ``0`` is honoured instead of being
            replaced by the default.

    Returns:
        ``running``: number of running jobs inspected;
        ``marked_stale``: number of jobs that were failed;
        ``stale_after_minutes``: the threshold that was applied.
    """
    # ``stale_after_minutes or default`` would treat an explicit 0 as
    # "not provided"; only None means that.
    if stale_after_minutes is None:
        threshold = settings.generation_job_stale_minutes
    else:
        threshold = stale_after_minutes

    running_jobs = (
        await db.execute(
            select(GenerationJob)
            .where(GenerationJob.status == "running")
            .order_by(GenerationJob.updated_at, GenerationJob.id)
        )
    ).scalars().all()

    marked_stale = 0

    for job in running_jobs:
        if not _is_stale_job(job, stale_after_minutes=threshold):
            continue

        # Load the linked story (if any) so the final snapshot can include it.
        story = None
        if job.story_id is not None:
            story = (
                await db.execute(
                    select(Story).where(
                        Story.id == job.story_id,
                        Story.user_id == job.user_id,
                    )
                )
            ).scalar_one_or_none()

        # finish_generation_job commits, so each stale job is persisted on
        # its own.
        await finish_generation_job(
            db,
            job=job,
            story=story,
            status="failed",
            current_step="generation_stale_failed",
            error_message=f"Generation job exceeded {threshold} minutes without progress.",
            message="Generation job was marked failed after exceeding the stale threshold.",
            metadata={"stale_after_minutes": threshold},
        )
        marked_stale += 1
        logger.warning(
            "generation_job_marked_stale",
            job_id=job.id,
            story_id=job.story_id,
            output_mode=job.output_mode,
            stale_after_minutes=threshold,
        )

    return {
        "running": len(running_jobs),
        "marked_stale": marked_stale,
        "stale_after_minutes": threshold,
    }
|